Commit 5e0b9df · 0 parent(s) · committed by root

initial commit

Files changed (48)
  1. .DS_Store +0 -0
  2. .gitignore +135 -0
  3. LICENSE +201 -0
  4. NOTICE +39 -0
  5. README.md +130 -0
  6. configs/hico_train.sh +40 -0
  7. configs/vcoco_train.sh +42 -0
  8. hico_20160224_det +1 -0
  9. hotr/data/datasets/__init__.py +24 -0
  10. hotr/data/datasets/builtin_meta.py +110 -0
  11. hotr/data/datasets/coco.py +156 -0
  12. hotr/data/datasets/hico.py +243 -0
  13. hotr/data/datasets/vcoco.py +467 -0
  14. hotr/data/evaluators/coco_eval.py +256 -0
  15. hotr/data/evaluators/hico_eval.py +242 -0
  16. hotr/data/evaluators/vcoco_eval.py +57 -0
  17. hotr/data/transforms/transforms.py +387 -0
  18. hotr/engine/__init__.py +14 -0
  19. hotr/engine/arg_parser.py +163 -0
  20. hotr/engine/evaluator_coco.py +62 -0
  21. hotr/engine/evaluator_hico.py +55 -0
  22. hotr/engine/evaluator_vcoco.py +87 -0
  23. hotr/engine/trainer.py +73 -0
  24. hotr/metrics/utils.py +90 -0
  25. hotr/metrics/vcoco/ap_agent.py +104 -0
  26. hotr/metrics/vcoco/ap_role.py +193 -0
  27. hotr/models/__init__.py +5 -0
  28. hotr/models/backbone.py +118 -0
  29. hotr/models/criterion.py +349 -0
  30. hotr/models/detr.py +187 -0
  31. hotr/models/detr_matcher.py +81 -0
  32. hotr/models/feed_forward.py +16 -0
  33. hotr/models/hotr.py +241 -0
  34. hotr/models/hotr_matcher.py +216 -0
  35. hotr/models/position_encoding.py +89 -0
  36. hotr/models/post_process.py +162 -0
  37. hotr/models/transformer.py +320 -0
  38. hotr/util/__init__.py +0 -0
  39. hotr/util/box_ops.py +110 -0
  40. hotr/util/logger.py +145 -0
  41. hotr/util/misc.py +401 -0
  42. hotr/util/ramp.py +23 -0
  43. imgs/mainfig.png +0 -0
  44. main.py +240 -0
  45. tools/launch.py +192 -0
  46. tools/run_dist_launch.sh +29 -0
  47. tools/run_dist_slurm.sh +33 -0
  48. v-coco +1 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitignore ADDED
@@ -0,0 +1,135 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
wandb/
checkpoints/

# old version
hotr/models/hotr_v1.py
Makefile
LICENSE ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 KAKAO BRAIN Corp. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
NOTICE ADDED
@@ -0,0 +1,39 @@
===============================================================================
DETR's Apache License 2.0
===============================================================================
The implementation code is based on the implementation in DETR
(https://github.com/facebookresearch/detr).
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Copyright (c) 2020 Facebook

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

===============================================================================
QPIC's Apache License 2.0
===============================================================================
The implementation code is based on the implementation in QPIC
(https://github.com/hitachi-rd-cv/qpic).
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Copyright (c) 2021 Hitachi

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md ADDED
@@ -0,0 +1,130 @@
# CPC_HOTR

This repository applies [Cross-Path Consistency Learning](https://arxiv.org/abs/2204.04836) to [HOTR](https://arxiv.org/abs/2104.13682), and is based on the official HOTR implementation found [here](https://github.com/kakaobrain/HOTR).

<div align="center">
  <img src="imgs/mainfig.png" width="900px" />
</div>


## 1. Environmental Setup
```bash
$ conda create -n HOTR_CPC python=3.7
$ conda install -c pytorch pytorch torchvision # PyTorch 1.7.1, torchvision 0.8.2, CUDA=11.0
$ conda install cython scipy
$ pip install pycocotools
$ pip install opencv-python
$ pip install wandb
```
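
After installing, you can sanity-check the environment from Python; this is a minimal sketch, assuming the version targets noted in the comment above (PyTorch 1.7.1, torchvision 0.8.2, CUDA 11.0):
```python
# Quick check against the versions targeted above.
import torch
import torchvision

print("torch:", torch.__version__)               # expected: 1.7.1
print("torchvision:", torchvision.__version__)   # expected: 0.8.2
print("CUDA available:", torch.cuda.is_available())
```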

## 2. HOI dataset setup
Our current version of HOTR supports experiments on both the [V-COCO](https://github.com/s-gupta/v-coco) and [HICO-DET](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) datasets.
Download the datasets under the cloned repository directory.
For HICO-DET, we use the [annotation files](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) provided by the PPDM authors.
Download the [list of actions](https://drive.google.com/open?id=1EeHNHuYyJI-qqDk_-5nay7Mb07tzZLsl) as `list_action.txt` and place it under the extracted hico-det directory.
Below we show how the files should be laid out (a sanity-check sketch follows the tree).
```bash
# V-COCO setup
$ git clone https://github.com/s-gupta/v-coco.git
$ cd v-coco
$ ln -s [:COCO_DIR] coco/images # COCO_DIR contains images of train2014 & val2014
$ python script_pick_annotations.py [:COCO_DIR]/annotations

# HICO-DET setup
$ tar -zxvf hico_20160224_det.tar.gz # move the extracted folder under the cloned repository

# dataset setup
HOTR
 │─ v-coco
 │   │─ data
 │   │   │─ instances_vcoco_all_2014.json
 │   │   :
 │   └─ coco
 │       │─ images
 │       │   │─ train2014
 │       │   │   │─ COCO_train2014_000000000009.jpg
 │       │   │   :
 │       │   └─ val2014
 │       │       │─ COCO_val2014_000000000042.jpg
 :       :       :
 │─ hico_20160224_det
 │   │─ list_action.txt
 │   │─ annotations
 │   │   │─ trainval_hico.json
 │   │   │─ test_hico.json
 │   │   └─ corre_hico.npy
 :   :
```

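Before training, it can help to verify the layout above; the following is a small, optional check (the paths mirror the tree shown here and are assumptions about where you placed the data):
```python
# Optional sanity check that the dataset layout matches the tree above.
from pathlib import Path

expected = [
    "v-coco/data/instances_vcoco_all_2014.json",
    "v-coco/coco/images/train2014",
    "v-coco/coco/images/val2014",
    "hico_20160224_det/list_action.txt",
    "hico_20160224_det/annotations/trainval_hico.json",
    "hico_20160224_det/annotations/test_hico.json",
    "hico_20160224_det/annotations/corre_hico.npy",
]
for rel in expected:
    print(f"{'OK     ' if Path(rel).exists() else 'MISSING'} {rel}")
```
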
If you downloaded the datasets to a directory of your own, simply set the `--data_path` argument to the directory where you placed them.
```bash
--data_path [:your_own_directory]/[v-coco/hico_20160224_det]
```

## 3. Training
After the preparation above, you can start training with the following commands.

For HICO-DET training:
```
GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/hico_train.sh
```
For V-COCO training:
```
GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/vcoco_train.sh
```

## 4. Evaluation
To evaluate the main inference path P1 (x->HOI), set `--path_id` to 0.
The augmented paths are indexed 1 to 3 (1: x->HO->I, 2: x->HI->O, 3: x->OI->H).

HICO-DET (use the exact same `--temperature` value that you used during training):
```
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --use_env main.py \
    --batch_size 2 \
    --HOIDet \
    --path_id 0 \
    --share_enc \
    --pretrained_dec \
    --share_dec_param \
    --num_hoi_queries [:query_num] \
    --object_threshold 0 \
    --temperature 0.2 \
    --no_aux_loss \
    --eval \
    --dataset_file hico-det \
    --data_path hico_20160224_det \
    --resume checkpoints/hico_det/hico_[:query_num].pth
```

V-COCO (use the exact same `--temperature` value that you used during training):
```
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --use_env main.py \
    --batch_size 2 \
    --HOIDet \
    --path_id 0 \
    --share_enc \
    --share_dec_param \
    --pretrained_dec \
    --num_hoi_queries [:query_num] \
    --temperature 0.05 \
    --object_threshold 0 \
    --no_aux_loss \
    --eval \
    --dataset_file vcoco \
    --data_path v-coco \
    --resume checkpoints/vcoco/vcoco_[:query_num].pth
```

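For reference, the `--path_id` convention above is just an index over the four decoding paths; a minimal illustration (this mapping restates the text above and plausibly matches the `p2`–`p4` names in the training configs, but it is not the repository's API):
```python
# Illustration of the --path_id convention described above (not repo code).
DECODING_PATHS = {
    0: "x -> HOI",      # P1, the main inference path
    1: "x -> HO -> I",  # augmented path (likely 'p2' in the configs)
    2: "x -> HI -> O",  # augmented path (likely 'p3')
    3: "x -> OI -> H",  # augmented path (likely 'p4')
}
print(DECODING_PATHS[0])  # "x -> HOI"
```
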
## Citation
```
@inproceedings{park2022consistency,
  title={Consistency Learning via Decoding Path Augmentation for Transformers in Human Object Interaction Detection},
  author={Park, Jihwan and Lee, SeungJun and Heo, Hwan and Choi, Hyeong Kyu and Kim, Hyunwoo J},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```
configs/hico_train.sh ADDED
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

set -x

EXP_DIR=logs_run_001
PY_ARGS=${@:1}

python -u main.py \
    --project_name CPC_HOTR_HICODET \
    --run_name ${EXP_DIR} \
    --HOIDet \
    --validate \
    --share_enc \
    --pretrained_dec \
    --use_consis \
    --share_dec_param \
    --epochs 90 \
    --lr_drop 60 \
    --lr 1e-4 \
    --lr_backbone 1e-5 \
    --ramp_up_epoch 30 \
    --path_id 0 \
    --num_hoi_queries 16 \
    --set_cost_idx 20 \
    --hoi_idx_loss_coef 1 \
    --hoi_act_loss_coef 10 \
    --backbone resnet50 \
    --hoi_consistency_loss_coef 0.2 \
    --hoi_idx_consistency_loss_coef 1 \
    --hoi_act_consistency_loss_coef 2 \
    --hoi_eos_coef 0.1 \
    --temperature 0.2 \
    --no_aux_loss \
    --hoi_aux_loss \
    --dataset_file hico-det \
    --data_path hico_20160224_det \
    --frozen_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
    --output_dir checkpoints/hico_det/ \
    --augpath_name [\'p2\',\'p3\',\'p4\'] \
    ${PY_ARGS}
configs/vcoco_train.sh ADDED
@@ -0,0 +1,42 @@
#!/usr/bin/env bash

set -x

EXP_DIR=logs_run_001
PY_ARGS=${@:1}

python -u main.py \
    --project_name CPC_HOTR_VCOCO \
    --run_name ${EXP_DIR} \
    --HOIDet \
    --validate \
    --share_enc \
    --pretrained_dec \
    --use_consis \
    --share_dec_param \
    --epochs 90 \
    --lr_drop 60 \
    --lr 1e-4 \
    --lr_backbone 1e-5 \
    --ramp_up_epoch 30 \
    --path_id 0 \
    --num_hoi_queries 16 \
    --set_cost_idx 10 \
    --hoi_idx_loss_coef 1 \
    --hoi_act_loss_coef 10 \
    --backbone resnet50 \
    --hoi_consistency_loss_coef 1 \
    --hoi_idx_consistency_loss_coef 1 \
    --hoi_act_consistency_loss_coef 10 \
    --stop_grad_stage \
    --hoi_eos_coef 0.1 \
    --temperature 0.05 \
    --no_aux_loss \
    --hoi_aux_loss \
    --dataset_file vcoco \
    --data_path v-coco \
    --frozen_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
    --output_dir checkpoints/vcoco/ \
    --augpath_name [\'p2\',\'p3\',\'p4\'] \
    ${PY_ARGS}
hico_20160224_det ADDED
@@ -0,0 +1 @@
/data/public/rw/datasets/hico_20160224_det/
hotr/data/datasets/__init__.py ADDED
@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision

from hotr.data.datasets.coco import build as build_coco
from hotr.data.datasets.vcoco import build as build_vcoco
from hotr.data.datasets.hico import build as build_hico

def get_coco_api_from_dataset(dataset):
    for _ in range(10):  # unwrap nested torch.utils.data.Subset wrappers (bounded at 10 levels)
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco


def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    elif args.dataset_file == 'vcoco':
        return build_vcoco(image_set, args)
    elif args.dataset_file == 'hico-det':
        return build_hico(image_set, args)
    raise ValueError(f'dataset {args.dataset_file} not supported')
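
As a usage sketch, `build_dataset` only needs an `args` object carrying the fields the selected builder reads; a minimal, hypothetical example (field values are placeholders, and real builders read additional fields such as `num_queries`):
```python
# Hypothetical usage of build_dataset; argument values are placeholders.
from argparse import Namespace
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset

args = Namespace(dataset_file='vcoco', data_path='v-coco')  # plus any other fields the builder reads
dataset = build_dataset(image_set='train', args=args)
coco_api = get_coco_api_from_dataset(dataset)  # unwraps Subsets to reach the COCO object
```
Note that `get_coco_api_from_dataset` falls through and returns `None` when the underlying dataset is not a `torchvision.datasets.CocoDetection`, as is the case for the HOI datasets below.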
hotr/data/datasets/builtin_meta.py ADDED
@@ -0,0 +1,110 @@
COCO_CATEGORIES = [
    {"color": [], "isthing": 0, "id": 0, "name": "N/A"},
    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
    {"color": [], "isthing": 0, "id": 12, "name": "N/A"},
    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
    {"color": [], "isthing": 0, "id": 26, "name": "N/A"},
    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
    {"color": [], "isthing": 0, "id": 29, "name": "N/A"},
    {"color": [], "isthing": 0, "id": 30, "name": "N/A"},
    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
    {"color": [], "isthing": 0, "id": 45, "name": "N/A"},
    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
    {"color": [], "isthing": 0, "id": 66, "name": "N/A"},
    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
    {"color": [], "isthing": 0, "id": 68, "name": "N/A"},
    {"color": [], "isthing": 0, "id": 69, "name": "N/A"},
    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
    {"color": [], "isthing": 0, "id": 71, "name": "N/A"},
    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
    {"color": [], "isthing": 0, "id": 83, "name": "N/A"},
    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
]

def _get_coco_instances_meta():
    thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    assert len(thing_ids) == 80, f"Length of thing ids : {len(thing_ids)}"

    thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
    thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
    thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]

    coco_classes = [k["name"] for k in COCO_CATEGORIES]

    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes,
        "thing_colors": thing_colors,
        "coco_classes": coco_classes,
    }
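
A short sketch of how this metadata is typically consumed; the printed values follow directly from the table above:
```python
# Example: map COCO's sparse category ids to the 80 contiguous training ids.
from hotr.data.datasets.builtin_meta import _get_coco_instances_meta

meta = _get_coco_instances_meta()
print(len(meta['thing_classes']))                     # 80 "thing" classes
print(meta['thing_dataset_id_to_contiguous_id'][1])   # person: COCO id 1 -> contiguous id 0
print(meta['thing_dataset_id_to_contiguous_id'][13])  # stop sign: COCO id 13 -> contiguous id 11 (id 12 is N/A)
print(meta['coco_classes'][3])                        # 'car' (list index == COCO id, N/A slots included)
```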
hotr/data/datasets/coco.py ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask

import hotr.data.transforms.transforms as T

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms, return_masks):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # (x1, y1, w, h) -> (x1, y1, x2, y2)
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        return image, target


def make_coco_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
    return dataset
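
The `build` function above reads only `args.data_path` and `args.masks`, so a minimal standalone invocation looks like this (the path is a placeholder):
```python
# Hypothetical standalone use of the COCO builder; the path is a placeholder.
from argparse import Namespace
from hotr.data.datasets.coco import build

args = Namespace(data_path='/path/to/coco', masks=False)
train_set = build('train', args)  # expects train2017/ and annotations/instances_train2017.json
img, target = train_set[0]        # transformed image tensor + dict with 'boxes', 'labels', 'image_id', ...
print(target['boxes'].shape)      # (num_objects, 4)
```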
hotr/data/datasets/hico.py ADDED
@@ -0,0 +1,243 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/data/datasets/hico.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
import json
from collections import defaultdict
import numpy as np

import torch
import torch.utils.data
import torchvision

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T


class HICODetection(torch.utils.data.Dataset):
    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
        self.img_set = img_set
        self.img_folder = img_folder
        with open(anno_file, 'r') as f:
            self.annotations = json.load(f)
        with open(action_list_file, 'r') as f:
            self.action_lines = f.readlines()
        self._transforms = transforms
        self.num_queries = num_queries
        self.get_metadata()

        if img_set == 'train':
            self.ids = []
            for idx, img_anno in enumerate(self.annotations):
                for hoi in img_anno['hoi_annotation']:
                    if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
                        break
                else:
                    self.ids.append(idx)
        else:
            self.ids = list(range(len(self.annotations)))

    ############################################################################
    # Number Method
    ############################################################################
    def get_metadata(self):
        meta = builtin_meta._get_coco_instances_meta()
        self.COCO_CLASSES = meta['coco_classes']
        self._valid_obj_ids = [id for id in meta['thing_dataset_id_to_contiguous_id'].keys()]
        self._valid_verb_ids, self._valid_verb_names = [], []
        for action_line in self.action_lines[2:]:
            act_id, act_name = action_line.split()
            self._valid_verb_ids.append(int(act_id))
            self._valid_verb_names.append(act_name)

    def get_valid_obj_ids(self):
        return self._valid_obj_ids

    def get_actions(self):
        return self._valid_verb_names

    def num_category(self):
        return len(self.COCO_CLASSES)

    def num_action(self):
        return len(self._valid_verb_ids)
    ############################################################################

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_anno = self.annotations[self.ids[idx]]

        img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
        w, h = img.size

        # cut out the GTs that exceed the number of object queries
        if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
            img_anno['annotations'] = img_anno['annotations'][:self.num_queries]

        boxes = [obj['bbox'] for obj in img_anno['annotations']]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.img_set == 'train':
            # Add index for confirming which boxes are kept after image transformation
            classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])]
        else:
            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']]
        classes = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        target['size'] = torch.as_tensor([int(h), int(w)])
        if self.img_set == 'train':
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]

            target['boxes'] = boxes
            target['labels'] = classes
            target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            if self._transforms is not None:
                img, target = self._transforms(img, target)

            kept_box_indices = [label[0] for label in target['labels']]

            target['labels'] = target['labels'][:, 1]

            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
            sub_obj_pairs = []
            for hoi in img_anno['hoi_annotation']:
                if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
                    continue
                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
                if sub_obj_pair in sub_obj_pairs:
                    verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
                else:
                    sub_obj_pairs.append(sub_obj_pair)
                    obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
                    verb_label = [0 for _ in range(len(self._valid_verb_ids))]
                    verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
                    sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
                    obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
                    verb_labels.append(verb_label)
                    sub_boxes.append(sub_box)
                    obj_boxes.append(obj_box)
            if len(sub_obj_pairs) == 0:
                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
                target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            else:
                target['pair_targets'] = torch.stack(obj_labels)
                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
                target['sub_boxes'] = torch.stack(sub_boxes)
                target['obj_boxes'] = torch.stack(obj_boxes)
        else:
            target['boxes'] = boxes
            target['labels'] = classes
            target['id'] = idx

            if self._transforms is not None:
                img, _ = self._transforms(img, None)

            hois = []
            for hoi in img_anno['hoi_annotation']:
                hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)

        return img, target

    def set_rare_hois(self, anno_file):
        with open(anno_file, 'r') as f:
            annotations = json.load(f)

        counts = defaultdict(lambda: 0)
        for img_anno in annotations:
            hois = img_anno['hoi_annotation']
            bboxes = img_anno['annotations']
            for hoi in hois:
                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
                           self._valid_verb_ids.index(hoi['category_id']))
                counts[triplet] += 1
        self.rare_triplets = []
        self.non_rare_triplets = []
        for triplet, count in counts.items():
            if count < 10:
                self.rare_triplets.append(triplet)
            else:
                self.non_rare_triplets.append(triplet)

    def load_correct_mat(self, path):
        self.correct_mat = np.load(path)


# Add color jitter to coco transforms
def make_hico_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    if image_set == 'test':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided HOI path {root} does not exist'
    PATHS = {
        'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
        'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
        'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
    }
    CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
    action_list_file = root / 'list_action.txt'

    img_folder, anno_file = PATHS[image_set]
    dataset = HICODetection(image_set, img_folder, anno_file, action_list_file, transforms=make_hico_transforms(image_set),
                            num_queries=args.num_queries)
    if image_set == 'val' or image_set == 'test':
        dataset.set_rare_hois(PATHS['train'][1])
        dataset.load_correct_mat(CORRECT_MAT_PATH)
    return dataset
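
The rare/non-rare split in `set_rare_hois` above hinges on a single threshold: a (subject, object, verb) triplet with fewer than 10 training occurrences is rare. A toy illustration of the same counting rule (the triplets here are made-up values):
```python
# Toy illustration of the rare/non-rare rule used in set_rare_hois above.
from collections import defaultdict

counts = defaultdict(int)
train_triplets = [(0, 39, 76)] * 3 + [(0, 17, 27)] * 25  # (sub_cls, obj_cls, verb) occurrences
for t in train_triplets:
    counts[t] += 1

rare     = [t for t, c in counts.items() if c < 10]   # [(0, 39, 76)]  -- seen 3 times
non_rare = [t for t, c in counts.items() if c >= 10]  # [(0, 17, 27)]  -- seen 25 times
print(rare, non_rare)
```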
hotr/data/datasets/vcoco.py ADDED
@@ -0,0 +1,467 @@
# Copyright (c) Kakaobrain, Inc. and its affiliates. All Rights Reserved
"""
V-COCO dataset which returns image_id for evaluation.
"""
from pathlib import Path

from PIL import Image
import os
import numpy as np
import json
import torch
import torch.utils.data
import torchvision

from torch.utils.data import Dataset
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T

class VCocoDetection(Dataset):
    def __init__(self,
                 img_folder,
                 ann_file,
                 all_file,
                 filter_empty_gt=True,
                 transforms=None):
        self.img_folder = img_folder
        self.file_meta = dict()
        self._transforms = transforms

        self.ann_file = ann_file
        self.all_file = all_file
        self.filter_empty_gt = filter_empty_gt

        # COCO initialize
        self.coco = COCO(self.all_file)
        self.COCO_CLASSES = builtin_meta._get_coco_instances_meta()['coco_classes']
        self.file_meta['coco_classes'] = self.COCO_CLASSES

        # Load V-COCO Dataset
        self.vcoco_all = self.load_vcoco(self.ann_file)

        # Save COCO annotation data
        self.image_ids = sorted(list(set(self.vcoco_all[0]['image_id'].reshape(-1))))

        # Filter Data
        if filter_empty_gt:
            self.filter_image_id()
        self.img_infos = self.load_annotations()

        # Refine Data
        self.save_action_name()
        self.mapping_inst_action_to_action()
        self.load_subobj_classes()
        self.CLASSES = self.act_list

    ############################################################################
    # Load V-COCO Dataset
    ############################################################################
    def load_vcoco(self, dir_name=None):
        with open(dir_name, 'rt') as f:
            vsrl_data = json.load(f)

        for i in range(len(vsrl_data)):
            vsrl_data[i]['role_object_id'] = np.array(vsrl_data[i]['role_object_id']).reshape((len(vsrl_data[i]['role_name']),-1)).T
            for j in ['ann_id', 'label', 'image_id']:
                vsrl_data[i][j] = np.array(vsrl_data[i][j]).reshape((-1,1))

        return vsrl_data

    ############################################################################
    # Refine Data
    ############################################################################
    def save_action_name(self):
        self.inst_act_list = list()
        self.act_list = list()

        # add instance action human classes
        self.num_subject_act = 0
        for vcoco in self.vcoco_all:
            self.inst_act_list.append('human_' + vcoco['action_name'])
            self.num_subject_act += 1

        # add instance action object classes
        for vcoco in self.vcoco_all:
            if len(vcoco['role_name']) == 3:
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][1])
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][2])
            elif len(vcoco['role_name']) < 2:
                continue
            else:
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][-1]) # when only two roles

        # add action classes
        for vcoco in self.vcoco_all:
            if len(vcoco['role_name']) == 3:
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][1])
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][2])
            else:
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][-1])

        # add to meta
        self.file_meta['action_classes'] = self.act_list

    def mapping_inst_action_to_action(self):
        sub_idx = 0
        obj_idx = self.num_subject_act

        self.sub_label_to_action = list()
        self.obj_label_to_action = list()

        for vcoco in self.vcoco_all:
            role_name = vcoco['role_name']

            self.sub_label_to_action.append(sub_idx)
            if len(role_name) == 3:
                self.sub_label_to_action.append(sub_idx)
                self.obj_label_to_action.append(obj_idx)
                self.obj_label_to_action.append(obj_idx+1)
                obj_idx += 2
            elif len(role_name) == 2:
                self.obj_label_to_action.append(obj_idx)
                obj_idx += 1
            else:
                self.obj_label_to_action.append(0)

            sub_idx += 1

    def load_subobj_classes(self):
        self.vcoco_labels = dict()
        for img in self.image_ids:
            self.vcoco_labels[img] = dict()
            self.vcoco_labels[img]['boxes'] = np.empty((0, 4), dtype=np.float32)
            self.vcoco_labels[img]['categories'] = np.empty((0), dtype=np.int32)

            ann_ids = self.coco.getAnnIds(imgIds=img, iscrowd=None)
            objs = self.coco.loadAnns(ann_ids)

            valid_ann_ids = []

            for i, obj in enumerate(objs):
                if 'ignore' in obj and obj['ignore'] == 1: continue

                x1 = obj['bbox'][0]
                y1 = obj['bbox'][1]
                x2 = x1 + np.maximum(0., obj['bbox'][2] - 1.)
                y2 = y1 + np.maximum(0., obj['bbox'][3] - 1.)

                if obj['area'] > 0 and x2 > x1 and y2 > y1:
                    bbox = np.array([x1, y1, x2, y2]).reshape(1, -1)
                    cls = obj['category_id']
                    self.vcoco_labels[img]['boxes'] = np.concatenate([self.vcoco_labels[img]['boxes'], bbox], axis=0)
                    self.vcoco_labels[img]['categories'] = np.concatenate([self.vcoco_labels[img]['categories'], [cls]], axis=0)

                    valid_ann_ids.append(ann_ids[i])

            num_valid_objs = len(valid_ann_ids)

            self.vcoco_labels[img]['agent_actions'] = -np.ones((num_valid_objs, self.num_action()), dtype=np.int32)
            self.vcoco_labels[img]['obj_actions'] = np.zeros((num_valid_objs, self.num_action()), dtype=np.int32)
            self.vcoco_labels[img]['role_id'] = -np.ones((num_valid_objs, self.num_action()), dtype=np.int32)

            for ix, ann_id in enumerate(valid_ann_ids):
                in_vcoco = np.where(self.vcoco_all[0]['ann_id'] == ann_id)[0]
                if in_vcoco.size > 0:
                    self.vcoco_labels[img]['agent_actions'][ix, :] = 0

                    agent_act_id = 0
                    obj_act_id = -1
                    for i, x in enumerate(self.vcoco_all):
                        has_label = np.where(np.logical_and(x['ann_id'] == ann_id, x['label'] == 1))[0]
                        if has_label.size > 0:
                            assert has_label.size == 1
                            rids = x['role_object_id'][has_label]

                            if rids.shape[1] == 3:
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id] = 1
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id+1] = 1
                                agent_act_id += 2
                            else:
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id] = 1
                                agent_act_id += 1
                                if rids.shape[1] == 1: obj_act_id += 1

                            for j in range(1, rids.shape[1]):
                                obj_act_id += 1
                                if rids[0, j] == 0: continue # no role
                                aid = np.where(valid_ann_ids == rids[0, j])[0]

                                self.vcoco_labels[img]['role_id'][ix, obj_act_id] = aid
                                self.vcoco_labels[img]['obj_actions'][aid, obj_act_id] = 1

                        else:
                            rids = x['role_object_id'][0]
                            if rids.shape[0] == 3:
                                agent_act_id += 2
                                obj_act_id += 2
                            else:
                                agent_act_id += 1
                                obj_act_id += 1

    ############################################################################
    # Annotation Loader
    ############################################################################
    # >>> 1. instance
    def load_instance_annotations(self, image_index):
        num_ann = self.vcoco_labels[image_index]['boxes'].shape[0]
        inst_action = np.zeros((num_ann, self.num_inst_action()), np.int64)  # np.int was removed from NumPy; use a concrete dtype
        inst_bbox = np.zeros((num_ann, 4), dtype=np.float32)
        inst_category = np.zeros((num_ann, ), dtype=np.int64)

        for idx in range(num_ann):
            inst_bbox[idx] = self.vcoco_labels[image_index]['boxes'][idx]
            inst_category[idx] = self.vcoco_labels[image_index]['categories'][idx] #+ 1 # category 1 ~ 81

            if inst_category[idx] == 1:
                act = self.vcoco_labels[image_index]['agent_actions'][idx]
                inst_action[idx, :self.num_subject_act] = act[np.unique(self.sub_label_to_action, return_index=True)[1]]

                # when person is the obj
                act = self.vcoco_labels[image_index]['obj_actions'][idx] # when person is the obj
                if act.any():
                    inst_action[idx, self.num_subject_act:] = act[np.nonzero(self.obj_label_to_action)[0]]
                    if inst_action[idx, :self.num_subject_act].sum(axis=-1) < 0:
                        inst_action[idx, :self.num_subject_act] = 0
            else:
                act = self.vcoco_labels[image_index]['obj_actions'][idx]
                inst_action[idx, self.num_subject_act:] = act[np.nonzero(self.obj_label_to_action)[0]]

        # >>> For Objects that are in COCO but not in V-COCO,
        # >>> Human -> [-1 * 26, 0 * 25]
        # >>> Object -> [0 * 51]
        # >>> Don't return anything for actions with max 0 or max -1
        max_val = inst_action.max(axis=1)
        if (max_val > 0).sum() == 0:
            print(f"No Annotations for {image_index}")
            print(inst_action)
            print(self.vcoco_labels[image_index]['agent_actions'][idx])
            print(self.vcoco_labels[image_index]['obj_actions'][idx])

        return inst_bbox[max_val > 0], inst_category[max_val > 0], inst_action[max_val > 0]

    # >>> 2. pair
    def load_pair_annotations(self, image_index):
        num_ann = self.vcoco_labels[image_index]['boxes'].shape[0]
        pair_action = np.zeros((0, self.num_action()), np.int64)
249
+ pair_bbox = np.zeros((0, 8), dtype=np.float32)
250
+ pair_target = np.zeros((0, ), dtype=np.int)
251
+
252
+ for idx in range(num_ann):
253
+ h_box = self.vcoco_labels[image_index]['boxes'][idx]
254
+ h_cat = self.vcoco_labels[image_index]['categories'][idx]
255
+ if h_cat != 1 : continue # human_id = 1
256
+
257
+ h_act = self.vcoco_labels[image_index]['agent_actions'][idx]
258
+ if np.any((h_act==-1)) : continue
259
+
260
+ o_act = dict()
261
+ for aid in range(self.num_action()):
262
+ if h_act[aid] == 0 : continue
263
+ o_id = self.vcoco_labels[image_index]['role_id'][idx, aid]
264
+ if o_id not in o_act : o_act[o_id] = list()
265
+ o_act[o_id].append(aid)
266
+
267
+ for o_id in o_act.keys():
268
+ if o_id == -1:
269
+ o_box = -np.ones((4, ))
270
+ o_cat = -1 # target is background
271
+ else:
272
+ o_box = self.vcoco_labels[image_index]['boxes'][o_id]
273
+ o_cat = self.vcoco_labels[image_index]['categories'][o_id] # category 0 ~ 80
274
+
275
+ box = np.concatenate([h_box, o_box]).astype(np.float32)
276
+ act = np.zeros((1, self.num_action()), np.int)
277
+ tar = np.zeros((1, ), np.int)
278
+ tar[0] = o_cat #+ 1 # category 1 ~ 81
279
+ for o_aid in o_act[o_id] : act[0, o_aid] = 1
280
+
281
+ pair_action = np.concatenate([pair_action, act], axis=0)
282
+ pair_bbox = np.concatenate([pair_bbox, np.expand_dims(box, axis=0)], axis=0)
283
+ pair_target = np.concatenate([pair_target, tar], axis=0)
284
+
285
+ return pair_bbox, pair_action, pair_target
286
+
287
+ # >>> 3. image infos
288
+ def load_annotations(self):
289
+ img_infos = []
290
+ for i in self.image_ids:
291
+ info = self.coco.loadImgs([i])[0]
292
+ img_infos.append(info)
293
+ return img_infos
294
+
295
+ ############################################################################
296
+ # Check Method
297
+ ############################################################################
298
+ def sum_action_ann_for_id(self, find_idx):
299
+ sum = 0
300
+ for action_ann in self.vcoco_all:
301
+ img_ids = action_ann['image_id']
302
+ img_labels = action_ann['label']
303
+
304
+ final_inds = img_ids[img_labels == 1]
305
+
306
+ if (find_idx in final_inds):
307
+ sum += 1
308
+ # sum of class-wise existence
309
+ return (sum > 0)
310
+
311
+ def filter_image_id(self):
312
+ empty_gt_list = []
313
+ for img_id in self.image_ids:
314
+ if not self.sum_action_ann_for_id(img_id):
315
+ empty_gt_list.append(img_id)
316
+
317
+ for remove_id in empty_gt_list:
318
+ rm_idx = self.image_ids.index(remove_id)
319
+ self.image_ids.remove(remove_id)
320
+
321
+ ############################################################################
322
+ # Preprocessing
323
+ ############################################################################
324
+ def prepare_img(self, idx):
325
+ img_info = self.img_infos[idx]
326
+ image = Image.open(os.path.join(self.img_folder, img_info['file_name'])).convert('RGB')
327
+ target = self.get_ann_info(idx)
328
+
329
+ w, h = image.size
330
+ target["orig_size"] = torch.as_tensor([int(h), int(w)])
331
+ target["size"] = torch.as_tensor([int(h), int(w)])
332
+
333
+ if self._transforms is not None:
334
+ img, target = self._transforms(image, target) # "size" gets converted here
335
+
336
+ return img, target
337
+
338
+ ############################################################################
339
+ # Get Method
340
+ ############################################################################
341
+ def __getitem__(self, idx):
342
+ img, target = self.prepare_img(idx)
343
+ return img, target
344
+
345
+ def __len__(self):
346
+ return len(self.image_ids)
347
+
348
+ def get_human_label_idx(self):
349
+ return self.sub_label_to_action
350
+
351
+ def get_object_label_idx(self):
352
+ return self.obj_label_to_action
353
+
354
+ def get_image_ids(self):
355
+ return self.image_ids
356
+
357
+ def get_categories(self):
358
+ return self.COCO_CLASSES
359
+
360
+ def get_inst_action(self):
361
+ return self.inst_act_list
362
+
363
+ def get_actions(self):
364
+ return self.act_list
365
+
366
+ def get_human_action(self):
367
+ return self.inst_act_list[:self.num_subject_act]
368
+
369
+ def get_object_action(self):
370
+ return self.inst_act_list[self.num_subject_act:]
371
+
372
+ def get_ann_info(self, idx):
373
+ img_idx = int(self.image_ids[idx])
374
+
375
+ # load each annotation
376
+ inst_bbox, inst_label, inst_actions = self.load_instance_annotations(img_idx)
377
+ pair_bbox, pair_actions, pair_targets = self.load_pair_annotations(img_idx)
378
+
379
+ sample = {
380
+ 'image_id' : torch.tensor([img_idx]),
381
+ 'boxes': torch.as_tensor(inst_bbox, dtype=torch.float32),
382
+ 'labels': torch.tensor(inst_label, dtype=torch.int64),
383
+ 'inst_actions': torch.tensor(inst_actions, dtype=torch.int64),
384
+ 'pair_boxes': torch.as_tensor(pair_bbox, dtype=torch.float32),
385
+ 'pair_actions': torch.tensor(pair_actions, dtype=torch.int64),
386
+ 'pair_targets': torch.tensor(pair_targets, dtype=torch.int64),
387
+ }
388
+
389
+ return sample
390
+
391
+ ############################################################################
392
+ # Number Method
393
+ ############################################################################
394
+ def num_category(self):
395
+ return len(self.COCO_CLASSES)
396
+
397
+ def num_action(self):
398
+ return len(self.act_list)
399
+
400
+ def num_inst_action(self):
401
+ return len(self.inst_act_list)
402
+
403
+ def num_human_act(self):
404
+ return len(self.inst_act_list[:self.num_subject_act])
405
+
406
+ def num_object_act(self):
407
+ return len(self.inst_act_list[self.num_subject_act:])
408
+
409
+ def make_hoi_transforms(image_set):
410
+ normalize = T.Compose([
411
+ T.ToTensor(),
412
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
413
+ ])
414
+
415
+ scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
416
+
417
+ if image_set == 'train':
418
+ return T.Compose([
419
+ T.RandomHorizontalFlip(),
420
+ T.ColorJitter(.4, .4, .4),
421
+ T.RandomSelect(
422
+ T.RandomResize(scales, max_size=1333),
423
+ T.Compose([
424
+ T.RandomResize([400, 500, 600]),
425
+ T.RandomSizeCrop(384, 600),
426
+ T.RandomResize(scales, max_size=1333),
427
+ ])
428
+ ),
429
+ normalize,
430
+ ])
431
+
432
+ if image_set == 'val':
433
+ return T.Compose([
434
+ T.RandomResize([800], max_size=1333),
435
+ normalize,
436
+ ])
437
+
438
+ if image_set == 'test':
439
+ return T.Compose([
440
+ T.RandomResize([800], max_size=1333),
441
+ normalize,
442
+ ])
443
+
444
+ raise ValueError(f'unknown {image_set}')
445
+
446
+ def build(image_set, args):
447
+ root = Path(args.data_path)
448
+ assert root.exists(), f'provided V-COCO path {root} does not exist'
449
+ PATHS = {
450
+ "train": (root / "coco/images/train2014/", root / "data/vcoco" / 'vcoco_trainval.json'),
451
+ "val": (root / "coco/images/val2014", root / "data/vcoco" / 'vcoco_test.json'),
452
+ "test": (root / "coco/images/val2014", root / "data/vcoco" / 'vcoco_test.json'),
453
+ }
454
+
455
+ img_folder, ann_file = PATHS[image_set]
456
+ all_file = root / "data/instances_vcoco_all_2014.json"
457
+ dataset = VCocoDetection(
458
+ img_folder = img_folder,
459
+ ann_file = ann_file,
460
+ all_file = all_file,
461
+ filter_empty_gt=True,
462
+ transforms = make_hoi_transforms(image_set)
463
+ )
464
+ dataset.file_meta['dataset_file'] = args.dataset_file
465
+ dataset.file_meta['image_set'] = image_set
466
+
467
+ return dataset
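
For orientation, a minimal sketch of driving the builder above from user code. The `SimpleNamespace` stand-in and the `v-coco` path are illustrative; `build()` only reads `data_path` and `dataset_file` from the args object, and the V-COCO folder layout expected by `PATHS` must exist on disk:

import torch
from types import SimpleNamespace
from hotr.data.datasets.vcoco import build

# Hypothetical args object; only the fields read by build() are filled in.
args = SimpleNamespace(data_path='v-coco', dataset_file='vcoco')
dataset = build('train', args)   # VCocoDetection with train-time transforms
img, target = dataset[0]         # normalized CHW image tensor + annotation dict
print(len(dataset), target['pair_boxes'].shape)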
hotr/data/evaluators/coco_eval.py ADDED
@@ -0,0 +1,256 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.

Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
at the end of the file, as python3 can suppress prints with contextlib
"""
import os
import contextlib
import copy
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from hotr.util.misc import all_gather


class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
    '''
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    '''
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
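
For reference, a minimal single-process sketch of the evaluator's update/accumulate/summarize cycle. The annotation-file path and the image id are placeholders, and the image id must exist in the ground-truth set for `COCO.loadRes` to accept the result:

import torch
from pycocotools.coco import COCO
from hotr.data.evaluators.coco_eval import CocoEvaluator

coco_gt = COCO('instances_val2014.json')  # illustrative annotation file path
evaluator = CocoEvaluator(coco_gt, iou_types=('bbox',))

# One made-up detection, keyed by image id, in the format prepare() expects:
# xyxy boxes plus per-box scores and COCO category labels.
res = {42: {'boxes': torch.tensor([[10., 20., 110., 220.]]),   # 42 must be in coco_gt
            'scores': torch.tensor([0.9]),
            'labels': torch.tensor([1])}}
evaluator.update(res)
evaluator.synchronize_between_processes()  # no-op outside distributed runs
evaluator.accumulate()
evaluator.summarize()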
hotr/data/evaluators/hico_eval.py ADDED
@@ -0,0 +1,242 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/data/evaluators/hico_eval.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import numpy as np
from collections import defaultdict

class HICOEvaluator():
    def __init__(self, preds, gts, rare_triplets, non_rare_triplets, correct_mat):
        self.overlap_iou = 0.5
        self.max_hois = 100

        self.rare_triplets = rare_triplets
        self.non_rare_triplets = non_rare_triplets

        self.fp = defaultdict(list)
        self.tp = defaultdict(list)
        self.score = defaultdict(list)
        self.sum_gts = defaultdict(lambda: 0)
        self.gt_triplets = []

        self.preds = []
        for img_preds in preds:
            img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items() if k != 'hoi_recognition_time'}
            bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])]
            hoi_scores = img_preds['verb_scores']
            verb_labels = np.tile(np.arange(hoi_scores.shape[1]), (hoi_scores.shape[0], 1))
            subject_ids = np.tile(img_preds['sub_ids'], (hoi_scores.shape[1], 1)).T
            object_ids = np.tile(img_preds['obj_ids'], (hoi_scores.shape[1], 1)).T

            hoi_scores = hoi_scores.ravel()
            verb_labels = verb_labels.ravel()
            subject_ids = subject_ids.ravel()
            object_ids = object_ids.ravel()

            if len(subject_ids) > 0:
                object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids])
                masks = correct_mat[verb_labels, object_labels]
                hoi_scores *= masks

                hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score}
                        for subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, hoi_scores)]
                hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
                hois = hois[:self.max_hois]
            else:
                hois = []

            self.preds.append({
                'predictions': bboxes,
                'hoi_prediction': hois
            })

        self.gts = []
        for img_gts in gts:
            img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id'}
            self.gts.append({
                'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])],
                'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']]
            })
            for hoi in self.gts[-1]['hoi_annotation']:
                triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'],
                           self.gts[-1]['annotations'][hoi['object_id']]['category_id'],
                           hoi['category_id'])

                if triplet not in self.gt_triplets:
                    self.gt_triplets.append(triplet)

                self.sum_gts[triplet] += 1

    def evaluate(self):
        for img_id, (img_preds, img_gts) in enumerate(zip(self.preds, self.gts)):
            print(f"Evaluating Score Matrix... : [{(img_id+1):>4}/{len(self.gts):<4}]", flush=True, end="\r")
            pred_bboxes = img_preds['predictions']
            gt_bboxes = img_gts['annotations']
            pred_hois = img_preds['hoi_prediction']
            gt_hois = img_gts['hoi_annotation']
            if len(gt_bboxes) != 0:
                bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes)
                self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps)
            else:
                for pred_hoi in pred_hois:
                    triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'],
                               pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id'])
                    if triplet not in self.gt_triplets:
                        continue
                    self.tp[triplet].append(0)
                    self.fp[triplet].append(1)
                    self.score[triplet].append(pred_hoi['score'])
        print("[stats] Score Matrix Generation completed!!")
        return self.compute_map()

    def compute_map(self):
        ap = defaultdict(lambda: 0)
        rare_ap = defaultdict(lambda: 0)
        non_rare_ap = defaultdict(lambda: 0)
        max_recall = defaultdict(lambda: 0)
        for triplet in self.gt_triplets:
            sum_gts = self.sum_gts[triplet]
            if sum_gts == 0:
                continue

            tp = np.array(self.tp[triplet])
            fp = np.array(self.fp[triplet])
            if len(tp) == 0:
                ap[triplet] = 0
                max_recall[triplet] = 0
                if triplet in self.rare_triplets:
                    rare_ap[triplet] = 0
                elif triplet in self.non_rare_triplets:
                    non_rare_ap[triplet] = 0
                else:
                    print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
                continue

            score = np.array(self.score[triplet])
            sort_inds = np.argsort(-score)
            fp = fp[sort_inds]
            tp = tp[sort_inds]
            fp = np.cumsum(fp)
            tp = np.cumsum(tp)
            rec = tp / sum_gts
            prec = tp / (fp + tp)
            ap[triplet] = self.voc_ap(rec, prec)
            max_recall[triplet] = np.amax(rec)
            if triplet in self.rare_triplets:
                rare_ap[triplet] = ap[triplet]
            elif triplet in self.non_rare_triplets:
                non_rare_ap[triplet] = ap[triplet]
            else:
                print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
        m_ap = np.mean(list(ap.values())) * 100  # percentage
        m_ap_rare = np.mean(list(rare_ap.values())) * 100  # percentage
        m_ap_non_rare = np.mean(list(non_rare_ap.values())) * 100  # percentage
        m_max_recall = np.mean(list(max_recall.values()))

        return {'mAP': m_ap, 'mAP rare': m_ap_rare, 'mAP non-rare': m_ap_non_rare, 'mean max recall': m_max_recall}

    def voc_ap(self, rec, prec):
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
        return ap

    def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps):
        pos_pred_ids = match_pairs.keys()
        vis_tag = np.zeros(len(gt_hois))
        pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
        if len(pred_hois) != 0:
            for pred_hoi in pred_hois:
                is_match = 0
                if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi['object_id'] in pos_pred_ids:
                    pred_sub_ids = match_pairs[pred_hoi['subject_id']]
                    pred_obj_ids = match_pairs[pred_hoi['object_id']]
                    pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']]
                    pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']]
                    pred_category_id = pred_hoi['category_id']
                    max_overlap = 0
                    max_gt_hoi = 0
                    for gt_hoi in gt_hois:
                        if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids \
                           and pred_category_id == gt_hoi['category_id']:
                            is_match = 1
                            min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])],
                                                 pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])])
                            if min_overlap_gt > max_overlap:
                                max_overlap = min_overlap_gt
                                max_gt_hoi = gt_hoi
                triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'],
                           pred_bboxes[pred_hoi['object_id']]['category_id'],
                           pred_hoi['category_id'])
                if triplet not in self.gt_triplets:
                    continue
                if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0:
                    self.fp[triplet].append(0)
                    self.tp[triplet].append(1)
                    vis_tag[gt_hois.index(max_gt_hoi)] = 1
                else:
                    self.fp[triplet].append(1)
                    self.tp[triplet].append(0)
                self.score[triplet].append(pred_hoi['score'])

    def compute_iou_mat(self, bbox_list1, bbox_list2):
        iou_mat = np.zeros((len(bbox_list1), len(bbox_list2)))
        if len(bbox_list1) == 0 or len(bbox_list2) == 0:
            return {}, {}
        for i, bbox1 in enumerate(bbox_list1):
            for j, bbox2 in enumerate(bbox_list2):
                iou_i = self.compute_IOU(bbox1, bbox2)
                iou_mat[i, j] = iou_i

        iou_mat_ov = iou_mat.copy()
        iou_mat[iou_mat >= self.overlap_iou] = 1
        iou_mat[iou_mat < self.overlap_iou] = 0

        match_pairs = np.nonzero(iou_mat)
        match_pairs_dict = {}
        match_pair_overlaps = {}
        if iou_mat.max() > 0:
            for i, pred_id in enumerate(match_pairs[1]):
                if pred_id not in match_pairs_dict.keys():
                    match_pairs_dict[pred_id] = []
                    match_pair_overlaps[pred_id] = []
                match_pairs_dict[pred_id].append(match_pairs[0][i])
                match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i], pred_id])
        return match_pairs_dict, match_pair_overlaps

    def compute_IOU(self, bbox1, bbox2):
        if isinstance(bbox1['category_id'], str):
            bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
        if isinstance(bbox2['category_id'], str):
            bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
        if bbox1['category_id'] == bbox2['category_id']:
            rec1 = bbox1['bbox']
            rec2 = bbox2['bbox']
            # computing the area of each rectangle
            S_rec1 = (rec1[2] - rec1[0] + 1) * (rec1[3] - rec1[1] + 1)
            S_rec2 = (rec2[2] - rec2[0] + 1) * (rec2[3] - rec2[1] + 1)

            # computing the sum of the areas
            sum_area = S_rec1 + S_rec2

            # find each edge of the intersection rectangle
            left_line = max(rec1[1], rec2[1])
            right_line = min(rec1[3], rec2[3])
            top_line = max(rec1[0], rec2[0])
            bottom_line = min(rec1[2], rec2[2])
            # judge if there is an intersection
            if left_line >= right_line or top_line >= bottom_line:
                return 0
            else:
                intersect = (right_line - left_line + 1) * (bottom_line - top_line + 1)
                return intersect / (sum_area - intersect)
        else:
            return 0
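
`voc_ap` above is the 11-point interpolated AP from PASCAL VOC. A quick numeric sanity check of the convention, with a made-up ranking of three detections (TP, FP, TP) against two ground-truth instances of one triplet:

import numpy as np

tp = np.cumsum([1, 0, 1])   # -> [1, 1, 2]
fp = np.cumsum([0, 1, 0])   # -> [0, 1, 1]
rec = tp / 2                # -> [0.5, 0.5, 1.0]
prec = tp / (fp + tp)       # -> [1.0, 0.5, 0.667]

# Same loop as voc_ap: at recall thresholds 0.0-0.5 the max precision is 1.0,
# at 0.6-1.0 it is 2/3, so AP = (6 * 1.0 + 5 * 2/3) / 11 ~= 0.848.
ap = 0.
for t in np.arange(0., 1.1, 0.1):
    p = 0 if np.sum(rec >= t) == 0 else np.max(prec[rec >= t])
    ap += p / 11.
print(ap)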
hotr/data/evaluators/vcoco_eval.py ADDED
@@ -0,0 +1,57 @@
# Copyright (c) KakaoBrain, Inc. and its affiliates. All Rights Reserved
"""
V-COCO evaluator that works in distributed mode.
"""
import os
import numpy as np
import torch

from hotr.util.misc import all_gather
from hotr.metrics.vcoco.ap_role import APRole
from functools import partial

def init_vcoco_evaluators(human_act_name, object_act_name):
    role_eval1 = APRole(act_name=object_act_name, scenario_flag=True, iou_threshold=0.5)
    role_eval2 = APRole(act_name=object_act_name, scenario_flag=False, iou_threshold=0.5)

    return role_eval1, role_eval2

class VCocoEvaluator(object):
    def __init__(self, args):
        self.img_ids = []
        self.eval_imgs = []
        self.role_eval1, self.role_eval2 = init_vcoco_evaluators(args.human_actions, args.object_actions)
        self.num_human_act = args.num_human_act
        self.action_idx = args.valid_ids

    def update(self, outputs):
        img_ids = list(np.unique(list(outputs.keys())))
        for img_num, img_id in enumerate(img_ids):
            print(f"Evaluating Score Matrix... : [{(img_num+1):>4}/{len(img_ids):<4}]", flush=True, end="\r")
            prediction = outputs[img_id]['prediction']
            target = outputs[img_id]['target']

            # score with prediction
            hbox, hcat, obox, ocat = list(map(lambda x: prediction[x],
                                              ['h_box', 'h_cat', 'o_box', 'o_cat']))

            assert 'pair_score' in prediction
            score = prediction['pair_score']

            hbox, hcat, obox, ocat, score = \
                list(map(lambda x: x.cpu().numpy(), [hbox, hcat, obox, ocat, score]))

            # ground-truth
            gt_h_inds = (target['labels'] == 1)
            gt_h_box = target['boxes'][gt_h_inds, :4].cpu().numpy()
            gt_h_act = target['inst_actions'][gt_h_inds, :self.num_human_act].cpu().numpy()

            gt_p_box = target['pair_boxes'].cpu().numpy()
            gt_p_act = target['pair_actions'].cpu().numpy()

            score = score[self.action_idx, :, :]
            gt_p_act = gt_p_act[:, self.action_idx]

            self.role_eval1.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act)
            self.role_eval2.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act)
            self.img_ids.append(img_id)
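
A shape sketch of one entry in the `outputs` dict that `update()` consumes. The field names follow the code above, but the tensor shapes and values are illustrative guesses inferred from the indexing, not the exact shapes produced by the model's post-processing:

import torch

outputs = {
    294: {  # image id
        'prediction': {
            'h_box': torch.rand(6, 4),                 # per-pair human boxes
            'h_cat': torch.ones(6),                    # human category (1)
            'o_box': torch.rand(6, 4),                 # per-pair object boxes
            'o_cat': torch.randint(0, 81, (6,)),       # object categories
            'pair_score': torch.rand(51, 6, 6),        # first axis sliced by args.valid_ids
        },
        'target': {                                    # dataset annotation dict
            'labels': torch.tensor([1, 44]),
            'boxes': torch.rand(2, 4),
            'inst_actions': torch.zeros(2, 51, dtype=torch.int64),
            'pair_boxes': torch.rand(1, 8),
            'pair_actions': torch.zeros(1, 51, dtype=torch.int64),
        },
    },
}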
hotr/data/transforms/transforms.py ADDED
@@ -0,0 +1,387 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

from hotr.util.box_ops import box_xyxy_to_cxcywh
from hotr.util.misc import interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    max_size = torch.as_tensor([w, h], dtype=torch.float32)

    fields = ["labels", "area", "iscrowd"]  # add additional fields
    if "inst_actions" in target.keys():
        fields.append("inst_actions")

    if "boxes" in target:
        boxes = target["boxes"]
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "pair_boxes" in target or ("sub_boxes" in target and "obj_boxes" in target):
        if "pair_boxes" in target:
            pair_boxes = target["pair_boxes"]
            hboxes = pair_boxes[:, :4]
            oboxes = pair_boxes[:, 4:]
        if "sub_boxes" in target and "obj_boxes" in target:
            hboxes = target["sub_boxes"]
            oboxes = target["obj_boxes"]

        cropped_hboxes = hboxes - torch.as_tensor([j, i, j, i])
        cropped_hboxes = torch.min(cropped_hboxes.reshape(-1, 2, 2), max_size)
        cropped_hboxes = cropped_hboxes.clamp(min=0)
        hboxes = cropped_hboxes.reshape(-1, 4)

        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            cropped_oboxes = oboxes[obj_mask] - torch.as_tensor([j, i, j, i])
            cropped_oboxes = torch.min(cropped_oboxes.reshape(-1, 2, 2), max_size)
            cropped_oboxes = cropped_oboxes.clamp(min=0)
            oboxes[obj_mask] = cropped_oboxes.reshape(-1, 4)
        else:
            cropped_oboxes = oboxes

        cropped_pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
        target["pair_boxes"] = cropped_pair_boxes
        pair_fields = ["pair_boxes", "pair_actions", "pair_targets"]

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            if field in target:  # guarded because there is no 'iscrowd' field in the v-coco dataset
                target[field] = target[field][keep]

    # remove elements that have redundant area
    if "boxes" in target and "labels" in target:
        cropped_boxes = target['boxes']
        cropped_labels = target['labels']

        cnr, keep_idx = [], []
        for idx, (cropped_box, cropped_lbl) in enumerate(zip(cropped_boxes, cropped_labels)):
            if str((cropped_box, cropped_lbl)) not in cnr:
                cnr.append(str((cropped_box, cropped_lbl)))
                keep_idx.append(True)
            else:
                keep_idx.append(False)

        for field in fields:
            if field in target:
                target[field] = target[field][keep_idx]

    # remove elements for which pair boxes have zero area
    if "pair_boxes" in target:
        cropped_hboxes = target["pair_boxes"][:, :4].reshape(-1, 2, 2)
        cropped_oboxes = target["pair_boxes"][:, 4:].reshape(-1, 2, 2)
        keep_h = torch.all(cropped_hboxes[:, 1, :] > cropped_hboxes[:, 0, :], dim=1)
        keep_o = torch.all(cropped_oboxes[:, 1, :] > cropped_oboxes[:, 0, :], dim=1)
        not_empty_o = torch.all(target["pair_boxes"][:, 4:] >= 0, dim=1)
        discard_o = (~keep_o) & not_empty_o
        if discard_o.sum() > 0:
            target["pair_boxes"][discard_o, 4:] = -1

        for pair_field in pair_fields:
            target[pair_field] = target[pair_field][keep_h]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "pair_boxes" in target:
        pair_boxes = target["pair_boxes"]
        hboxes = pair_boxes[:, :4]
        oboxes = pair_boxes[:, 4:]

        # human flip
        hboxes = hboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])

        # object flip
        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            o_tmp = oboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
            oboxes[obj_mask] = o_tmp[obj_mask]

        pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
        target["pair_boxes"] = pair_boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "pair_boxes" in target:
        hboxes = target["pair_boxes"][:, :4]
        scaled_hboxes = hboxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        hboxes = scaled_hboxes

        oboxes = target["pair_boxes"][:, 4:]
        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            scaled_oboxes = oboxes[obj_mask] * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
            oboxes[obj_mask] = scaled_oboxes

        target["pair_boxes"] = torch.cat([hboxes, oboxes], dim=-1)

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):
    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes

        if "pair_boxes" in target:
            hboxes = target["pair_boxes"][:, :4]
            hboxes = box_xyxy_to_cxcywh(hboxes)
            hboxes = hboxes / torch.tensor([w, h, w, h], dtype=torch.float32)

            oboxes = target["pair_boxes"][:, 4:]
            obj_mask = (oboxes[:, 0] != -1)
            if obj_mask.sum() != 0:
                oboxes[obj_mask] = box_xyxy_to_cxcywh(oboxes[obj_mask])
                oboxes[obj_mask] = oboxes[obj_mask] / torch.tensor([w, h, w, h], dtype=torch.float32)

            pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
            target["pair_boxes"] = pair_boxes

        return image, target


class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, img, target):
        return self.color_jitter(img), target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
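
A minimal sketch of running one of these pipelines end to end, mirroring the `val` branch of `make_hoi_transforms` in the V-COCO dataset file; the stand-in image and target are illustrative:

import torch
from PIL import Image
import hotr.data.transforms.transforms as T

image = Image.new('RGB', (640, 480))  # placeholder image
target = {'boxes': torch.tensor([[10., 10., 100., 200.]]),
          'labels': torch.tensor([1]),
          'area': torch.tensor([90. * 190.])}

tfm = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img, target = tfm(image, target)
# img is now a normalized CHW float tensor; target['boxes'] is cxcywh in [0, 1]
print(img.shape, target['boxes'])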
hotr/engine/__init__.py ADDED
@@ -0,0 +1,14 @@
from .evaluator_vcoco import vcoco_evaluate, vcoco_accumulate
from .evaluator_hico import hico_evaluate

def hoi_evaluator(args, model, criterion, postprocessors, data_loader, device, thr=0):
    if args.dataset_file == 'vcoco':
        return vcoco_evaluate(model, criterion, postprocessors, data_loader, device, args.output_dir, thr, args=args)
    elif args.dataset_file == 'hico-det':
        return hico_evaluate(model, postprocessors, data_loader, device, thr, args=args)
    else:
        raise NotImplementedError

def hoi_accumulator(args, total_res, print_results=False, wandb=False):
    if args.dataset_file == 'vcoco':
        return vcoco_accumulate(total_res, args, print_results, wandb)
    else:
        raise NotImplementedError
hotr/engine/arg_parser.py ADDED
@@ -0,0 +1,163 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : engine/arg_parser.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # Modified arguments are represented with *
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+ import argparse
10
+ import hotr.util.misc as utils
11
+
12
+ def get_args_parser():
13
+ parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
14
+ parser.add_argument('--lr', default=1e-4, type=float)
15
+ parser.add_argument('--lr_backbone', default=1e-5, type=float)
16
+ parser.add_argument('--batch_size', default=2, type=int)
17
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
18
+ parser.add_argument('--epochs', default=100, type=int)
19
+ parser.add_argument('--lr_drop', default=80, type=int)
20
+ parser.add_argument('--clip_max_norm', default=0.1, type=float,
21
+ help='gradient clipping max norm')
22
+
23
+ # DETR Model parameters
24
+ parser.add_argument('--frozen_weights', type=str, default=None,
25
+ help="Path to the pretrained model. If set, only the mask head will be trained")
26
+ parser.add_argument('--pretrain_interaction_tf', type=str, default=None,
27
+ help="Path to the pretrained model. If set, only the mask head will be trained")
28
+
29
+ # DETR Backbone
30
+ parser.add_argument('--backbone', default='resnet50', type=str,
31
+ help="Name of the convolutional backbone to use")
32
+ parser.add_argument('--dilation', action='store_true',
33
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)")
34
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
35
+ help="Type of positional embedding to use on top of the image features")
36
+
37
+ # DETR Transformer (= Encoder, Instance Decoder)
38
+ parser.add_argument('--enc_layers', default=6, type=int,
39
+ help="Number of encoding layers in the transformer")
40
+ parser.add_argument('--dec_layers', default=6, type=int,
41
+ help="Number of decoding layers in the transformer")
42
+ parser.add_argument('--dim_feedforward', default=2048, type=int,
43
+ help="Intermediate size of the feedforward layers in the transformer blocks")
44
+ parser.add_argument('--hidden_dim', default=256, type=int,
45
+ help="Size of the embeddings (dimension of the transformer)")
46
+ parser.add_argument('--dropout', default=0.1, type=float,
47
+ help="Dropout applied in the transformer")
48
+ parser.add_argument('--nheads', default=8, type=int,
49
+ help="Number of attention heads inside the transformer's attentions")
50
+ parser.add_argument('--num_queries', default=100, type=int,
51
+ help="Number of query slots")
52
+ parser.add_argument('--pre_norm', action='store_true')
53
+ parser.add_argument('--decoder_form', default=2, type=int,
54
+ help="1-decoder or 2-decoder")
55
+ # Segmentation
56
+ parser.add_argument('--masks', action='store_true',
57
+ help="Train segmentation head if the flag is provided")
58
+
59
+ # Loss Option
60
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
61
+ help="Disables auxiliary decoding losses (loss at each layer)")
62
+
63
+ # Loss coefficients (DETR)
64
+ parser.add_argument('--mask_loss_coef', default=1, type=float)
65
+ parser.add_argument('--dice_loss_coef', default=1, type=float)
66
+ parser.add_argument('--bbox_loss_coef', default=5, type=float)
67
+ parser.add_argument('--giou_loss_coef', default=2, type=float)
68
+ parser.add_argument('--eos_coef', default=0.1, type=float,
69
+ help="Relative classification weight of the no-object class")
70
+
71
+ # Matcher (DETR)
72
+ parser.add_argument('--set_cost_class', default=1, type=float,
73
+ help="Class coefficient in the matching cost")
74
+ parser.add_argument('--set_cost_bbox', default=5, type=float,
75
+ help="L1 box coefficient in the matching cost")
76
+ parser.add_argument('--set_cost_giou', default=2, type=float,
77
+ help="giou box coefficient in the matching cost")
78
+
79
+ # * HOI Detection
80
+ parser.add_argument('--HOIDet', action='store_true',
81
+ help="Train HOI Detection head if the flag is provided")
82
+ parser.add_argument('--share_enc', action='store_true',
83
+ help="Share the Encoder in DETR for HOI Detection if the flag is provided")
84
+ parser.add_argument('--pretrained_dec', action='store_true',
85
+ help="Use Pre-trained Decoder in DETR for Interaction Decoder if the flag is provided")
86
+ parser.add_argument('--hoi_enc_layers', default=1, type=int,
87
+ help="Number of decoding layers in HOI transformer")
88
+ parser.add_argument('--hoi_dec_layers', default=1, type=int,
89
+ help="Number of decoding layers in HOI transformer")
90
+ parser.add_argument('--hoi_nheads', default=8, type=int,
91
+ help="Number of decoding layers in HOI transformer")
92
+ parser.add_argument('--hoi_dim_feedforward', default=2048, type=int,
93
+ help="Number of decoding layers in HOI transformer")
94
+ # parser.add_argument('--hoi_mode', type=str, default=None, help='[inst | pair | all]')
95
+ parser.add_argument('--num_hoi_queries', default=100, type=int,
96
+ help="Number of Queries for Interaction Decoder")
97
+ parser.add_argument('--hoi_aux_loss', action='store_true')
98
+
99
+
100
+ # * HOTR Matcher
101
+ parser.add_argument('--set_cost_idx', default=1, type=float,
102
+ help="IDX coefficient in the matching cost")
103
+ parser.add_argument('--set_cost_act', default=1, type=float,
104
+ help="Action coefficient in the matching cost")
105
+ parser.add_argument('--set_cost_tgt', default=1, type=float,
106
+ help="Target coefficient in the matching cost")
107
+
108
+ # * HOTR Loss coefficients
109
+ parser.add_argument('--temperature', default=0.05, type=float, help="temperature")
110
+ parser.add_argument('--hoi_consistency_loss_coef', default=1, type=float)
111
+ parser.add_argument('--hoi_idx_loss_coef', default=1, type=float)
112
+ parser.add_argument('--hoi_idx_consistency_loss_coef', default=1, type=float)
113
+ parser.add_argument('--hoi_act_loss_coef', default=1, type=float)
114
+ parser.add_argument('--hoi_act_consistency_loss_coef', default=1, type=float)
115
+ parser.add_argument('--hoi_tgt_loss_coef', default=1, type=float)
116
+ parser.add_argument('--hoi_tgt_consistency_loss_coef', default=1, type=float)
117
+ parser.add_argument('--hoi_eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class")
118
+
119
+ parser.add_argument('--ramp_down_epoch',default=10000,type=int)
120
+ parser.add_argument('--ramp_up_epoch',default=0,type=int)
121
+ #consistency
122
+ parser.add_argument('--use_consis',action='store_true',help='use consistency regularization')
123
+ parser.add_argument('--share_dec_param',action='store_true',help = 'share decoder parameters of all stages')
124
+ parser.add_argument("--augpath_name", type=utils.arg_as_list,default=[],
125
+ help='choose which augmented inference paths to use. (p2:x->HO->I,p3:x->HI->O,p4:x->OI->H)')
126
+ parser.add_argument('--stop_grad_stage',action='store_true',help='Do not back propogate loss to previous stage')
127
+ parser.add_argument('--path_id', default=0, type=int)
128
+
129
+ parser.add_argument('--sep_enc_forward',action='store_true')
130
+
131
+ # * dataset parameters
132
+ parser.add_argument('--dataset_file', help='[coco | vcoco]')
133
+ parser.add_argument('--data_path', type=str)
134
+ parser.add_argument('--object_threshold', type=float, default=0, help='Threshold for object confidence')
135
+
136
+ # machine parameters
137
+ parser.add_argument('--output_dir', default='',
138
+ help='path where to save, empty for no saving')
139
+ parser.add_argument('--custom_path', default='',
140
+ help="Data path for custom inference. Only required for custom_main.py")
141
+ parser.add_argument('--device', default='cuda',
142
+ help='device to use for training / testing')
143
+ parser.add_argument('--seed', default=42, type=int)
144
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
145
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
146
+ help='start epoch')
147
+ parser.add_argument('--num_workers', default=2, type=int)
148
+
149
+ # mode
150
+ parser.add_argument('--eval', action='store_true', help="Only evaluate results if the flag is provided")
151
+ parser.add_argument('--validate', action='store_true', help="Validate after every epoch")
152
+
153
+ # distributed training parameters
154
+ parser.add_argument('--world_size', default=1, type=int,
155
+ help='number of distributed processes')
156
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
157
+
158
+ # * WandB
159
+ parser.add_argument('--wandb', action='store_true')
160
+ parser.add_argument('--project_name', default='hotr_cpc')
161
+ parser.add_argument('--group_name', default='mlv')
162
+ parser.add_argument('--run_name', default='run_000001')
163
+ return parser
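
For reference, --augpath_name above takes a list-valued argparse type. A minimal
sketch of what an arg_as_list-style helper could look like; the real one is the
`utils.arg_as_list` imported in this file and may differ in detail:

    import argparse
    import ast

    def arg_as_list(s):
        # Parse "[p2,p3]" or "['p2','p3']" into a Python list.
        try:
            v = ast.literal_eval(s)
        except (ValueError, SyntaxError):
            # fall back to a comma split for unquoted items such as [p2,p3]
            v = [item for item in s.strip('[]').split(',') if item]
        if not isinstance(v, list):
            raise argparse.ArgumentTypeError(f'expected a list, got {s!r}')
        return v

    parser = argparse.ArgumentParser()
    parser.add_argument('--augpath_name', type=arg_as_list, default=[])
    print(parser.parse_args(['--augpath_name', '[p2,p3]']).augpath_name)  # ['p2', 'p3']
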
hotr/engine/evaluator_coco.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import torch
3
+ import hotr.util.misc as utils
4
+ import hotr.util.logger as loggers
5
+ from hotr.data.evaluators.coco_eval import CocoEvaluator
6
+
7
+ @torch.no_grad()
8
+ def coco_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
9
+ model.eval()
10
+ criterion.eval()
11
+
12
+ metric_logger = loggers.MetricLogger(delimiter=" ")
13
+ metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
14
+ header = 'Evaluation'
15
+
16
+ iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
17
+ coco_evaluator = CocoEvaluator(base_ds, iou_types)
18
+ print_freq = len(data_loader)
19
+ # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
20
+
21
+ print("\n>>> [MS-COCO Evaluation] <<<")
22
+ for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
23
+ samples = samples.to(device)
24
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
25
+
26
+ outputs = model(samples)
27
+ loss_dict = criterion(outputs, targets)
28
+ weight_dict = criterion.weight_dict
29
+
30
+ # reduce losses over all GPUs for logging purposes
31
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
32
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]
33
+ for k, v in loss_dict_reduced.items() if k in weight_dict}
34
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
35
+ for k, v in loss_dict_reduced.items()}
36
+ metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
37
+ **loss_dict_reduced_scaled,
38
+ **loss_dict_reduced_unscaled)
39
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
40
+
41
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
42
+ results = postprocessors['bbox'](outputs, orig_target_sizes)
43
+ res = {target['image_id'].item(): output for target, output in zip(targets, results)}
44
+ if coco_evaluator is not None:
45
+ coco_evaluator.update(res)
46
+
47
+ # gather the stats from all processes
48
+ metric_logger.synchronize_between_processes()
49
+ print("\n>>> [Averaged stats] <<<\n", metric_logger)
50
+ if coco_evaluator is not None:
51
+ coco_evaluator.synchronize_between_processes()
52
+
53
+ # accumulate predictions from all images
54
+ if coco_evaluator is not None:
55
+ coco_evaluator.accumulate()
56
+ coco_evaluator.summarize()
57
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
58
+ if coco_evaluator is not None:
59
+ if 'bbox' in postprocessors.keys():
60
+ stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
61
+
62
+ return stats, coco_evaluator
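
A toy illustration of the res mapping handed to CocoEvaluator.update above:
image_id -> postprocessed prediction dict (keys follow DETR's PostProcess; the
values here are made up):

    import torch

    targets = [{'image_id': torch.tensor(42)}, {'image_id': torch.tensor(7)}]
    results = [
        {'scores': torch.tensor([0.9]), 'labels': torch.tensor([1]),
         'boxes': torch.tensor([[0., 0., 10., 10.]])},
        {'scores': torch.tensor([0.8]), 'labels': torch.tensor([3]),
         'boxes': torch.tensor([[5., 5., 20., 20.]])},
    ]
    res = {t['image_id'].item(): out for t, out in zip(targets, results)}
    assert set(res) == {42, 7}
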
hotr/engine/evaluator_hico.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ from typing import Iterable
5
+ import numpy as np
6
+ import copy
7
+ import itertools
8
+
9
+ import torch
10
+
11
+ import hotr.util.misc as utils
12
+ import hotr.util.logger as loggers
13
+ from hotr.data.evaluators.hico_eval import HICOEvaluator
14
+
15
+ @torch.no_grad()
16
+ def hico_evaluate(model, postprocessors, data_loader, device, thr, args=None):
17
+ model.eval()
18
+
19
+ metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
20
+ header = 'Evaluation Inference (HICO-DET)'
21
+
22
+ preds = []
23
+ gts = []
24
+ indices = []
25
+ hoi_recognition_time = []
26
+
27
+ for samples, targets in metric_logger.log_every(data_loader, 50, header):
28
+ samples = samples.to(device)
29
+ targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]
30
+
31
+ outputs = model(samples)
32
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
33
+ results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det', args=args)
34
+ hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)
35
+
36
+ preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
37
+ # a deep copy avoids a runtime error when gathering targets across processes
38
+ gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))
39
+
40
+ print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")
41
+
42
+ # gather the stats from all processes
43
+ metric_logger.synchronize_between_processes()
44
+
45
+ img_ids = [img_gts['id'] for img_gts in gts]
46
+ _, indices = np.unique(img_ids, return_index=True)
47
+ preds = [img_preds for i, img_preds in enumerate(preds) if i in indices]
48
+ gts = [img_gts for i, img_gts in enumerate(gts) if i in indices]
49
+
50
+ evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
51
+ data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)
52
+
53
+ stats = evaluator.evaluate()
54
+
55
+ return stats
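
The deduplication above keeps one prediction/ground-truth pair per image after
utils.all_gather; np.unique(..., return_index=True) returns the index of the
first occurrence of each id, as in this toy run:

    import numpy as np

    img_ids = [3, 1, 3, 2, 1]                 # ids after gathering from all ranks
    _, keep = np.unique(img_ids, return_index=True)
    print(np.sort(keep))                      # [0 1 3] -> first copy of each image
    preds = ['p3a', 'p1a', 'p3b', 'p2a', 'p1b']
    print([p for i, p in enumerate(preds) if i in keep])  # ['p3a', 'p1a', 'p2a']
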
hotr/engine/evaluator_vcoco.py ADDED
@@ -0,0 +1,87 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : hotr/engine/evaluator_vcoco.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ import os
9
+ import torch
10
+ import time
11
+ import datetime
12
+
13
+ import hotr.util.misc as utils
14
+ import hotr.util.logger as loggers
15
+ from hotr.data.evaluators.vcoco_eval import VCocoEvaluator
16
+ from hotr.util.box_ops import rescale_bboxes, rescale_pairs
17
+
18
+ import wandb
19
+
20
+ @torch.no_grad()
21
+ def vcoco_evaluate(model, criterion, postprocessors, data_loader, device, output_dir, thr, args=None):
22
+ model.eval()
23
+ criterion.eval()
24
+
25
+ metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
26
+ header = 'Evaluation Inference (V-COCO)'
27
+
28
+ print_freq = 1 # len(data_loader)
29
+ res = {}
30
+ hoi_recognition_time = []
31
+
32
+ for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
33
+ samples = samples.to(device)
34
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
35
+
36
+ outputs = model(samples)
37
+ loss_dict = criterion(outputs, targets)
38
+ loss_dict_reduced = utils.reduce_dict(loss_dict) # ddp gathering
39
+
40
+ orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
41
+ results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='vcoco',args=args)
42
+ targets = process_target(targets, orig_target_sizes)
43
+ hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)
44
+
45
+ res.update(
46
+ {target['image_id'].item():\
47
+ {'target': target, 'prediction': output} for target, output in zip(targets, results)
48
+ }
49
+ )
50
+ print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")
51
+
52
+ start_time = time.time()
53
+ gather_res = utils.all_gather(res)
54
+ total_res = {}
55
+ for dist_res in gather_res:
56
+ total_res.update(dist_res)
57
+ total_time = time.time() - start_time
58
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
59
+ print(f"[stats] Distributed Gathering Time : {total_time_str}")
60
+
61
+ return total_res
62
+
63
+ def vcoco_accumulate(total_res, args, print_results, wandb_log):
64
+ vcoco_evaluator = VCocoEvaluator(args)
65
+ vcoco_evaluator.update(total_res)
66
+ print("[stats] Score matrix generation completed!")
67
+
68
+ scenario1 = vcoco_evaluator.role_eval1.evaluate(print_results)
69
+ scenario2 = vcoco_evaluator.role_eval2.evaluate(print_results)
70
+
71
+ if wandb_log:
72
+ wandb.log({
73
+ 'scenario1': scenario1,
74
+ 'scenario2': scenario2
75
+ })
76
+
77
+ return scenario1, scenario2
78
+
79
+ def process_target(targets, target_sizes):
80
+ for idx, (target, target_size) in enumerate(zip(targets, target_sizes)):
81
+ labels = target['labels']
82
+ valid_boxes_inds = (labels > 0)
83
+
84
+ targets[idx]['boxes'] = rescale_bboxes(target['boxes'], target_size) # boxes
85
+ targets[idx]['pair_boxes'] = rescale_pairs(target['pair_boxes'], target_size) # pairs
86
+
87
+ return targets
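
utils.all_gather returns one res dict per process; the merge in vcoco_evaluate is
a plain dict union keyed by image id. A two-process toy example:

    gather_res = [
        {1: {'target': 't1', 'prediction': 'p1'}},   # rank 0
        {2: {'target': 't2', 'prediction': 'p2'}},   # rank 1
    ]
    total_res = {}
    for dist_res in gather_res:
        total_res.update(dist_res)
    print(sorted(total_res))  # [1, 2]
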
hotr/engine/trainer.py ADDED
@@ -0,0 +1,73 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : engine/trainer.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ import math
9
+ import torch
10
+ import sys
11
+ import hotr.util.misc as utils
12
+ import hotr.util.logger as loggers
13
+ from hotr.util.ramp import *
14
+ from typing import Iterable
15
+ import wandb
16
+
17
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
18
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
19
+ device: torch.device, epoch: int, max_epoch: int, ramp_up_epoch: int, rampdown_epoch: int, max_consis_coef: float = 1.0, max_norm: float = 0, dataset_file: str = 'coco', log: bool = False):
20
+ model.train()
21
+ criterion.train()
22
+ metric_logger = loggers.MetricLogger(mode="train", delimiter=" ")
23
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
24
+ space_fmt = str(len(str(max_epoch)))
25
+ header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch+1, end_epoch=max_epoch, fill=space_fmt)
26
+ print_freq = int(len(data_loader)/5)
27
+
28
+ if epoch<=rampdown_epoch:
29
+ consis_coef=sigmoid_rampup(epoch,ramp_up_epoch,max_consis_coef)
30
+ else:
31
+ consis_coef=cosine_rampdown(epoch-rampdown_epoch,max_epoch-rampdown_epoch,max_consis_coef)
32
+ print(consis_coef)
33
+ print(f"\n>>> Epoch #{(epoch+1)}")
34
+ for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
35
+ samples = samples.to(device)
36
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
37
+
38
+ outputs = model(samples)
39
+ loss_dict = criterion(outputs, targets, log)
40
+ #print(loss_dict)
41
+ weight_dict = criterion.weight_dict
42
+
43
+ losses = sum(loss_dict[k] * weight_dict[k] * consis_coef if 'consistency' in k else loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
44
+
45
+ # reduce losses over all GPUs for logging purposes
46
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
47
+ loss_dict_reduced_unscaled = {f'{k}_unscaled': v
48
+ for k, v in loss_dict_reduced.items()}
49
+ loss_dict_reduced_scaled = {k: v * weight_dict[k]*consis_coef if 'consistency' in k else v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
50
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
51
+ loss_value = losses_reduced_scaled.item()
52
+
53
+
54
+ if not math.isfinite(loss_value):
55
+ print("Loss is {}, stopping training".format(loss_value))
56
+ print(loss_dict_reduced)
57
+ sys.exit(1)
58
+
59
+ optimizer.zero_grad()
60
+ losses.backward()
61
+ if max_norm > 0:
62
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
63
+ optimizer.step()
64
+
65
+ metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
66
+ if "obj_class_error" in loss_dict:
67
+ metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error'])
68
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
69
+ # gather the stats from all processes
70
+ metric_logger.synchronize_between_processes()
71
+ if utils.get_rank() == 0 and log: wandb.log(loss_dict_reduced_scaled)
72
+ print("Averaged stats:", metric_logger)
73
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
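
The consistency weight above ramps up with a sigmoid and down with a cosine.
A self-contained sketch of the schedule; the real sigmoid_rampup/cosine_rampdown
live in hotr/util/ramp.py, and the formulas below are the conventional ones
(Tarvainen & Valpola, 2017), which may differ in detail:

    import numpy as np

    def sigmoid_rampup(current, rampup_length, max_coef=1.0):
        if rampup_length == 0:
            return max_coef
        x = np.clip(current / rampup_length, 0.0, 1.0)
        return max_coef * float(np.exp(-5.0 * (1.0 - x) ** 2))

    def cosine_rampdown(current, rampdown_length, max_coef=1.0):
        x = np.clip(current / rampdown_length, 0.0, 1.0)
        return max_coef * float(0.5 * (np.cos(np.pi * x) + 1.0))

    ramp_up_epoch, ramp_down_epoch, max_epoch = 10, 50, 100
    for epoch in [0, 5, 10, 60, 90]:
        if epoch <= ramp_down_epoch:
            c = sigmoid_rampup(epoch, ramp_up_epoch)
        else:
            c = cosine_rampdown(epoch - ramp_down_epoch, max_epoch - ramp_down_epoch)
        print(epoch, round(c, 3))   # rises to 1.0 by epoch 10, decays after epoch 50
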
hotr/metrics/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def count_parameters(model):
5
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
6
+
7
+
8
+ def compute_overlap(a, b):
9
+ if type(a) == torch.Tensor:
10
+ if len(a.shape) == 2:
11
+ area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)
12
+
13
+ iw = torch.min(a[:, 2].unsqueeze(dim=1), b[:, 2]) - torch.max(a[:, 0].unsqueeze(dim=1), b[:, 0])
14
+ ih = torch.min(a[:, 3].unsqueeze(dim=1), b[:, 3]) - torch.max(a[:, 1].unsqueeze(dim=1), b[:, 1])
15
+
16
+ iw[iw<0] = 0
17
+ ih[ih<0] = 0
18
+
19
+ ua = torch.unsqueeze((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), dim=1) + area - iw * ih
20
+ ua[ua < 1e-8] = 1e-8
21
+
22
+ intersection = iw * ih
23
+
24
+ return intersection / ua
25
+
26
+ elif type(a) == np.ndarray:
27
+ if len(a.shape) == 2:
28
+ area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K)
29
+
30
+ iw = np.minimum(np.expand_dims(a[:, 2], axis=1), np.expand_dims(b[:, 2], axis=0)) \
31
+ - np.maximum(np.expand_dims(a[:, 0], axis=1), np.expand_dims(b[:, 0], axis=0)) \
32
+ + 1
33
+ ih = np.minimum(np.expand_dims(a[:, 3], axis=1), np.expand_dims(b[:, 3], axis=0)) \
34
+ - np.maximum(np.expand_dims(a[:, 1], axis=1), np.expand_dims(b[:, 1], axis=0)) \
35
+ + 1
36
+
37
+ iw[iw<0] = 0 # (N, K)
38
+ ih[ih<0] = 0 # (N, K)
39
+
40
+ intersection = iw * ih
41
+
42
+ ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - intersection
43
+ ua[ua < 1e-8] = 1e-8
44
+
45
+ return intersection / ua
46
+
47
+ elif len(a.shape) == 1:
48
+ area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K)
49
+
50
+ iw = np.minimum(np.expand_dims([a[2]], axis=1), np.expand_dims(b[:, 2], axis=0)) \
51
+ - np.maximum(np.expand_dims([a[0]], axis=1), np.expand_dims(b[:, 0], axis=0))
52
+ ih = np.minimum(np.expand_dims([a[3]], axis=1), np.expand_dims(b[:, 3], axis=0)) \
53
+ - np.maximum(np.expand_dims([a[1]], axis=1), np.expand_dims(b[:, 1], axis=0))
54
+
55
+ iw[iw<0] = 0 # (N, K)
56
+ ih[ih<0] = 0 # (N, K)
57
+
58
+ ua = np.expand_dims([(a[2] - a[0] + 1) * (a[3] - a[1] + 1)], axis=1) + area - iw * ih
59
+ ua[ua < 1e-8] = 1e-8
60
+
61
+ intersection = iw * ih
62
+
63
+ return intersection / ua
64
+
65
+
66
+ def _compute_ap(recall, precision):
67
+ """ Compute the average precision, given the recall and precision curves.
68
+ Code originally from https://github.com/rbgirshick/py-faster-rcnn.
69
+ # Arguments
70
+ recall: The recall curve (list).
71
+ precision: The precision curve (list).
72
+ # Returns
73
+ The average precision as computed in py-faster-rcnn.
74
+ """
75
+ # correct AP calculation
76
+ # first append sentinel values at the end
77
+ mrec = np.concatenate(([0.], recall, [1.]))
78
+ mpre = np.concatenate(([0.], precision, [0.]))
79
+
80
+ # compute the precision envelope
81
+ for i in range(mpre.size - 1, 0, -1):
82
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
83
+
84
+ # to calculate area under PR curve, look for points
85
+ # where X axis (recall) changes value
86
+ i = np.where(mrec[1:] != mrec[:-1])[0]
87
+
88
+ # and sum (\Delta recall) * prec
89
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
90
+ return ap
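
A worked toy example for _compute_ap: three detections, two annotated positives.
The precision envelope flattens to [1.0, 2/3] across the two recall jumps, so
AP = 0.5 * 1.0 + 0.5 * (2/3):

    import numpy as np
    from hotr.metrics.utils import _compute_ap

    recall    = np.array([0.5, 0.5, 1.0])   # cumulative TP / num_ann
    precision = np.array([1.0, 0.5, 2/3])   # cumulative TP / (TP + FP)
    print(round(_compute_ap(recall, precision), 3))  # 0.833
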
hotr/metrics/vcoco/ap_agent.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+ from hotr.metrics.utils import _compute_ap, compute_overlap
3
+ import pdb
4
+
5
+ class APAgent(object):
6
+ def __init__(self, act_name, iou_threshold=0.5):
7
+ self.act_name = act_name
8
+ self.iou_threshold = iou_threshold
9
+
10
+ self.fp = [np.zeros((0,))] * len(act_name)
11
+ self.tp = [np.zeros((0,))] * len(act_name)
12
+ self.score = [np.zeros((0,))] * len(act_name)
13
+ self.num_ann = [0] * len(act_name)
14
+
15
+ def add_data(self, box, act, cat, i_box, i_act):
16
+ for label in range(len(self.act_name)):
17
+ i_inds = (i_act[:, label] == 1)
18
+ self.num_ann[label] += i_inds.sum()
19
+
20
+ n_pred = box.shape[0]
21
+ if n_pred == 0 : return
22
+
23
+ ######################
24
+ valid_i_inds = (i_act[:, 0] != -1) # (n_i, ) # both in COCO & V-COCO
25
+
26
+ overlaps = compute_overlap(box, i_box) # (n_pred, n_i)
27
+ assigned_input = np.argmax(overlaps, axis=1) # (n_pred, )
28
+ v_inds = valid_i_inds[assigned_input] # (n_pred, )
29
+
30
+ n_valid = v_inds.sum()
31
+
32
+ if n_valid == 0 : return
33
+ valid_box = box[v_inds]
34
+ valid_act = act[v_inds]
35
+ valid_cat = cat[v_inds]
36
+
37
+ ######################
38
+ s = valid_act * np.expand_dims(valid_cat, axis=1) # (n_v, #act)
39
+
40
+ for label in range(len(self.act_name)):
41
+ inds = np.argsort(s[:, label])[::-1] # (n_v, )
42
+ self.score[label] = np.append(self.score[label], s[inds, label])
43
+
44
+ correct_i_inds = (i_act[:, label] == 1)
45
+ if correct_i_inds.sum() == 0:
46
+ self.tp[label] = np.append(self.tp[label], np.array([0]*n_valid))
47
+ self.fp[label] = np.append(self.fp[label], np.array([1]*n_valid))
48
+ continue
49
+
50
+ overlaps = compute_overlap(valid_box[inds], i_box) # (n_v, n_i)
51
+ assigned_input = np.argmax(overlaps, axis=1) # (n_v, )
52
+ max_overlap = overlaps[range(n_valid), assigned_input] # (n_v, )
53
+
54
+ iou_inds = (max_overlap > self.iou_threshold) & correct_i_inds[assigned_input] # (n_v, )
55
+
56
+ i_nonzero = iou_inds.nonzero()[0]
57
+ i_inds = assigned_input[i_nonzero]
58
+ i_iou = np.unique(i_inds, return_index=True)[1]
59
+ i_tp = i_nonzero[i_iou]
60
+
61
+ t = np.zeros(n_valid, dtype=np.uint8)
62
+ t[i_tp] = 1
63
+ f = 1-t
64
+
65
+ self.tp[label] = np.append(self.tp[label], t)
66
+ self.fp[label] = np.append(self.fp[label], f)
67
+
68
+ def evaluate(self):
69
+ average_precisions = dict()
70
+ for label in range(len(self.act_name)):
71
+ if self.num_ann[label] == 0:
72
+ average_precisions[label] = 0
73
+ continue
74
+
75
+ # sort by score
76
+ indices = np.argsort(-self.score[label])
77
+ self.fp[label] = self.fp[label][indices]
78
+ self.tp[label] = self.tp[label][indices]
79
+
80
+ # compute false positives and true positives
81
+ self.fp[label] = np.cumsum(self.fp[label])
82
+ self.tp[label] = np.cumsum(self.tp[label])
83
+
84
+ # compute recall and precision
85
+ recall = self.tp[label] / self.num_ann[label]
86
+ precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)
87
+
88
+ # compute average precision
89
+ average_precisions[label] = _compute_ap(recall, precision) * 100
90
+
91
+ print('\n================== AP (Agent) ===================')
92
+ s, n = 0, 0
93
+
94
+ for label in range(len(self.act_name)):
95
+ label_name = "_".join(self.act_name[label].split("_")[1:])
96
+ print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
97
+ s += average_precisions[label]
98
+ n += 1
99
+
100
+ mAP = s/n
101
+ print('| mAP(agent): {:0.2f}'.format(mAP))
102
+ print('----------------------------------------------------')
103
+
104
+ return mAP
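
The evaluate() bookkeeping above in miniature: scores sorted descending, tp/fp
cumulated, then precision/recall handed to _compute_ap (values are made up):

    import numpy as np
    from hotr.metrics.utils import _compute_ap

    score = np.array([0.9, 0.3, 0.7])
    tp    = np.array([1,   0,   1  ])
    fp    = 1 - tp
    order = np.argsort(-score)
    tp, fp = np.cumsum(tp[order]), np.cumsum(fp[order])
    num_ann = 2
    recall    = tp / num_ann
    precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    print(_compute_ap(recall, precision) * 100)  # 100.0 -> AP in percent, as above
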
hotr/metrics/vcoco/ap_role.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ import torch
3
+ from hotr.metrics.utils import _compute_ap, compute_overlap
4
+
5
+ class APRole(object):
6
+ def __init__(self, act_name, scenario_flag=True, iou_threshold=0.5):
7
+ self.act_name = act_name
8
+ self.iou_threshold = iou_threshold
9
+
10
+ self.scenario_flag = scenario_flag
11
+ # scenario_1 : True
12
+ # scenario_2 : False
13
+
14
+ self.fp = [np.zeros((0,))] * len(act_name)
15
+ self.tp = [np.zeros((0,))] * len(act_name)
16
+ self.score = [np.zeros((0,))] * len(act_name)
17
+ self.num_ann = [0] * len(act_name)
18
+
19
+ def add_data(self, h_box, o_box, score, i_box, i_act, p_box, p_act):
20
+ # i_box, i_act : to check if only in COCO
21
+ for label in range(len(self.act_name)):
22
+ p_inds = (p_act[:, label] == 1)
23
+ self.num_ann[label] += p_inds.sum()
24
+
25
+ if h_box.shape[0] == 0 : return # if no prediction, just return
26
+ # COCO (O), V-COCO (X) __or__ collater, no ann in image => ignore
27
+
28
+ valid_i_inds = (i_act[:, 0] != -1) # (n_i, )
29
+ overlaps = compute_overlap(h_box, i_box) # (n_h, n_i)
30
+ assigned_input = np.argmax(overlaps, axis=1) # (n_h, )
31
+ v_inds = valid_i_inds[assigned_input] # (n_h, )
32
+
33
+ h_box = h_box[v_inds]
34
+ score = score[:, v_inds, :]
35
+ if h_box.shape[0] == 0 : return
36
+ n_h = h_box.shape[0]
37
+
38
+ valid_p_inds = (p_act[:, 0] != -1) | (p_box[:, 0] != -1)
39
+ p_act = p_act[valid_p_inds]
40
+ p_box = p_box[valid_p_inds]
41
+
42
+ n_o = o_box.shape[0]
43
+ if n_o == 0:
44
+ # no prediction for object
45
+ score = score.squeeze(axis=2) # (#act, n_h)
46
+
47
+ for label in range(len(self.act_name)):
48
+ h_inds = np.argsort(score[label])[::-1] # (n_h, )
49
+ self.score[label] = np.append(self.score[label], score[label, h_inds])
50
+
51
+ p_inds = (p_act[:, label] == 1)
52
+ if p_inds.sum() == 0:
53
+ self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
54
+ self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
55
+ continue
56
+
57
+ h_overlaps = compute_overlap(h_box[h_inds], p_box[p_inds, :4]) # (n_h, n_p)
58
+ assigned_p = np.argmax(h_overlaps, axis=1) # (n_h, )
59
+ h_max_overlap = h_overlaps[range(n_h), assigned_p] # (n_h, )
60
+
61
+ o_overlaps = compute_overlap(np.zeros((n_h, 4)), p_box[p_inds][assigned_p, 4:8])
62
+ o_overlaps = np.diag(o_overlaps) # (n_h, )
63
+
64
+ no_role_inds = (p_box[p_inds][assigned_p, 4] == -1) # (n_h, )
65
+ # human (o), action (o), no object in actual image
66
+
67
+ h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, )
68
+ o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, )
69
+
70
+ # scenario1 is not considered (already no object)
71
+ o_iou_inds[no_role_inds] = 1
72
+
73
+ iou_inds = (h_iou_inds & o_iou_inds)
74
+ p_nonzero = iou_inds.nonzero()[0]
75
+ p_inds = assigned_p[p_nonzero]
76
+ p_iou = np.unique(p_inds, return_index=True)[1]
77
+ p_tp = p_nonzero[p_iou]
78
+
79
+ t = np.zeros(n_h, dtype=np.uint8)
80
+ t[p_tp] = 1
81
+ f = 1-t
82
+
83
+ self.tp[label] = np.append(self.tp[label], t)
84
+ self.fp[label] = np.append(self.fp[label], f)
85
+
86
+ else:
87
+ s_obj_argmax = np.argmax(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h)
88
+ s_obj_max = np.max(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h)
89
+
90
+ h_overlaps = compute_overlap(h_box, p_box[:, :4]) # (n_h, n_p)
91
+
92
+ for label in range(len(self.act_name)):
93
+ h_inds = np.argsort(s_obj_max[label])[::-1] # (n_h, )
94
+ self.score[label] = np.append(self.score[label], s_obj_max[label, h_inds])
95
+
96
+ p_inds = (p_act[:, label] == 1) # (n_p, )
97
+ if p_inds.sum() == 0:
98
+ self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
99
+ self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
100
+ continue
101
+
102
+ h_overlaps = compute_overlap(h_box[h_inds], p_box[:, :4]) # (n_h, n_p) # match for all hboxes
103
+ h_max_overlap = np.max(h_overlaps, axis=1) # (n_h, ) # get the max overlap for hbox
104
+
105
+ # for same human, multiple pairs exist. find the human box that has the same idx with max overlap hbox.
106
+ h_max_temp = np.expand_dims(h_max_overlap, axis=1)
107
+ h_over_thresh = (h_overlaps == h_max_temp) # (n_h, n_p)
108
+ h_over_thresh = h_over_thresh & np.expand_dims(p_inds, axis=0) # (n_h, n_p) # find only for current act
109
+
110
+ h_valid = h_over_thresh.sum(axis=1)>0 # (n_h, ) # at least one is True
111
+ # h_valid -> if all is False, then argmax becomes 0. <- prevent
112
+ assigned_p = np.argmax(h_over_thresh, axis=1) # (n_h, ) # p only for current act
113
+
114
+ o_mapping_box = o_box[s_obj_argmax[label]][h_inds] # (n_h, ) # find where T is.
115
+ p_mapping_box = p_box[assigned_p, 4:8] # (n_h, 4)
116
+
117
+ o_overlaps = compute_overlap(o_mapping_box, p_mapping_box)
118
+ o_overlaps = np.diag(o_overlaps) # (n_h, )
119
+ o_overlaps.setflags(write=1)
120
+ if (~h_valid).sum() > 0:
121
+ o_overlaps[~h_valid] = 0 # (n_h, )
122
+
123
+ no_role_inds = (p_box[assigned_p, 4] == -1) # (n_h, )
124
+ nan_box_inds = np.all(o_mapping_box == 0, axis=1) | np.all(np.isnan(o_mapping_box), axis=1)
125
+ no_role_inds = no_role_inds & h_valid
126
+ nan_box_inds = nan_box_inds & h_valid
127
+
128
+ h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, )
129
+ o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, )
130
+
131
+ if self.scenario_flag: # scenario_1
132
+ o_iou_inds[no_role_inds & nan_box_inds] = 1
133
+ o_iou_inds[no_role_inds & ~nan_box_inds] = 0
134
+ else: # scenario_2
135
+ o_iou_inds[no_role_inds] = 1
136
+
137
+ iou_inds = (h_iou_inds & o_iou_inds)
138
+ p_nonzero = iou_inds.nonzero()[0]
139
+ p_inds = assigned_p[p_nonzero]
140
+ p_iou = np.unique(p_inds, return_index=True)[1]
141
+ p_tp = p_nonzero[p_iou]
142
+
143
+ t = np.zeros(n_h, dtype=np.uint8)
144
+ t[p_tp] = 1
145
+ f = 1-t
146
+
147
+ self.tp[label] = np.append(self.tp[label], t)
148
+ self.fp[label] = np.append(self.fp[label], f)
149
+
150
+ def evaluate(self, print_log=False):
151
+ average_precisions = dict()
152
+ role_num = 1 if self.scenario_flag else 2
153
+ for label in range(len(self.act_name)):
154
+
155
+ # sort by score
156
+ indices = np.argsort(-self.score[label])
157
+ self.fp[label] = self.fp[label][indices]
158
+ self.tp[label] = self.tp[label][indices]
159
+
160
+
161
+ if self.num_ann[label] == 0:
162
+ average_precisions[label] = 0
163
+ continue
164
+
165
+ # compute false positives and true positives
166
+ self.fp[label] = np.cumsum(self.fp[label])
167
+ self.tp[label] = np.cumsum(self.tp[label])
168
+
169
+ # compute recall and precision
170
+ recall = self.tp[label] / self.num_ann[label]
171
+ precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)
172
+
173
+ # compute average precision
174
+ average_precisions[label] = _compute_ap(recall, precision) * 100
175
+
176
+ if print_log: print(f'\n============= AP (Role scenario_{role_num}) ==============')
177
+ s, n = 0, 0
178
+
179
+ for label in range(len(self.act_name)):
180
+ if 'point' in self.act_name[label]:
181
+ continue
182
+ label_name = "_".join(self.act_name[label].split("_")[1:])
183
+ if print_log: print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
184
+ if self.num_ann[label] != 0 :
185
+ s += average_precisions[label]
186
+ n += 1
187
+
188
+ mAP = s/n
189
+ if print_log:
190
+ print('| mAP(role scenario_{:d}): {:0.2f}'.format(role_num, mAP))
191
+ print('----------------------------------------------------')
192
+
193
+ return mAP
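
The scenario handling above in miniature: for ground-truth pairs with no role
object (p_box[..., 4] == -1), scenario 1 credits a match only when the predicted
object box is empty/NaN, while scenario 2 always credits it (toy values):

    import numpy as np

    no_role_inds = np.array([True,  True,  False])
    nan_box_inds = np.array([True,  False, False])  # predicted object box all-zero/NaN
    o_iou_inds   = np.array([False, False, True ])  # IoU test on the object box

    s1 = o_iou_inds.copy()
    s1[no_role_inds & nan_box_inds]  = True
    s1[no_role_inds & ~nan_box_inds] = False
    s2 = o_iou_inds.copy()
    s2[no_role_inds] = True
    print(s1, s2)  # [ True False  True] [ True  True  True]
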
hotr/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ from .detr import build
3
+
4
+ def build_model(args):
5
+ return build(args)
hotr/models/backbone.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Backbone modules.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torch import nn
11
+ from torchvision.models._utils import IntermediateLayerGetter
12
+ from typing import Dict, List
13
+
14
+ from hotr.util.misc import NestedTensor, is_main_process
15
+
16
+ from .position_encoding import build_position_encoding
17
+
18
+
19
+ class FrozenBatchNorm2d(torch.nn.Module):
20
+ """
21
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
22
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt,
23
+ without which any other models than torchvision.models.resnet[18,34,50,101]
24
+ produce nans.
25
+ """
26
+
27
+ def __init__(self, n):
28
+ super(FrozenBatchNorm2d, self).__init__()
29
+ self.register_buffer("weight", torch.ones(n))
30
+ self.register_buffer("bias", torch.zeros(n))
31
+ self.register_buffer("running_mean", torch.zeros(n))
32
+ self.register_buffer("running_var", torch.ones(n))
33
+
34
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
35
+ missing_keys, unexpected_keys, error_msgs):
36
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
37
+ if num_batches_tracked_key in state_dict:
38
+ del state_dict[num_batches_tracked_key]
39
+
40
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
41
+ state_dict, prefix, local_metadata, strict,
42
+ missing_keys, unexpected_keys, error_msgs)
43
+
44
+ def forward(self, x):
45
+ # move reshapes to the beginning
46
+ # to make it fuser-friendly
47
+ w = self.weight.reshape(1, -1, 1, 1)
48
+ b = self.bias.reshape(1, -1, 1, 1)
49
+ rv = self.running_var.reshape(1, -1, 1, 1)
50
+ rm = self.running_mean.reshape(1, -1, 1, 1)
51
+ eps = 1e-5
52
+ scale = w * (rv + eps).rsqrt()
53
+ bias = b - rm * scale
54
+ return x * scale + bias
55
+
56
+
57
+ class BackboneBase(nn.Module):
58
+
59
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
60
+ super().__init__()
61
+ for name, parameter in backbone.named_parameters():
62
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
63
+ parameter.requires_grad_(False)
64
+ if return_interm_layers:
65
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
66
+ else:
67
+ return_layers = {'layer4': "0"}
68
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
69
+ self.num_channels = num_channels
70
+
71
+ def forward(self, tensor_list: NestedTensor):
72
+ xs = self.body(tensor_list.tensors)
73
+ out: Dict[str, NestedTensor] = {}
74
+ for name, x in xs.items():
75
+ m = tensor_list.mask
76
+ assert m is not None
77
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
78
+ out[name] = NestedTensor(x, mask)
79
+ return out
80
+
81
+
82
+ class Backbone(BackboneBase):
83
+ """ResNet backbone with frozen BatchNorm."""
84
+ def __init__(self, name: str,
85
+ train_backbone: bool,
86
+ return_interm_layers: bool,
87
+ dilation: bool):
88
+ backbone = getattr(torchvision.models, name)(
89
+ replace_stride_with_dilation=[False, False, dilation],
90
+ pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
91
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
92
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
93
+
94
+
95
+ class Joiner(nn.Sequential):
96
+ def __init__(self, backbone, position_embedding):
97
+ super().__init__(backbone, position_embedding)
98
+
99
+ def forward(self, tensor_list: NestedTensor):
100
+ xs = self[0](tensor_list)
101
+ out: List[NestedTensor] = []
102
+ pos = []
103
+ for name, x in xs.items():
104
+ out.append(x)
105
+ # position encoding
106
+ pos.append(self[1](x).to(x.tensors.dtype))
107
+
108
+ return out, pos
109
+
110
+
111
+ def build_backbone(args):
112
+ position_embedding = build_position_encoding(args)
113
+ train_backbone = args.lr_backbone > 0
114
+ return_interm_layers = False # args.masks
115
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
116
+ model = Joiner(backbone, position_embedding)
117
+ model.num_channels = backbone.num_channels
118
+ return model
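
A quick sanity sketch for FrozenBatchNorm2d: with copied statistics it should
match an eval-mode BatchNorm2d up to the eps handling noted in the docstring
(import path assumed from this repository's layout):

    import torch
    from hotr.models.backbone import FrozenBatchNorm2d

    bn = torch.nn.BatchNorm2d(8).eval()
    with torch.no_grad():
        bn.weight.uniform_(); bn.bias.uniform_()
        bn.running_mean.uniform_(); bn.running_var.uniform_(0.5, 2.0)

    fbn = FrozenBatchNorm2d(8)
    fbn.load_state_dict(bn.state_dict(), strict=False)  # num_batches_tracked is dropped
    x = torch.randn(2, 8, 4, 4)
    print(torch.allclose(bn(x), fbn(x), atol=1e-6))  # True
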
hotr/models/criterion.py ADDED
@@ -0,0 +1,349 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : main.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import copy
11
+ import numpy as np
12
+ import itertools
13
+ from torch import nn
14
+
15
+ from hotr.util import box_ops
16
+ from hotr.util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized)
17
+
18
+ class SetCriterion(nn.Module):
19
+ """ This class computes the loss for DETR.
20
+ The process happens in two steps:
21
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
22
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
23
+ """
24
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
25
+ """ Create the criterion.
26
+ Parameters:
27
+ num_classes: number of object categories, omitting the special no-object category
28
+ matcher: module able to compute a matching between targets and proposals
29
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
30
+ eos_coef: relative classification weight applied to the no-object category
31
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
32
+ """
33
+ super().__init__()
34
+ self.num_classes = num_classes
35
+ self.matcher = matcher
36
+ self.weight_dict = weight_dict
37
+ self.losses = losses
38
+ self.eos_coef=eos_coef
39
+
40
+ self.HOI_losses = HOI_losses
41
+ self.HOI_matcher = HOI_matcher
42
+ self.use_consis = args.use_consis and len(args.augpath_name) > 0
43
+ self.num_path = 1+len(args.augpath_name)
44
+ if args:
45
+ self.HOI_eos_coef = args.hoi_eos_coef
46
+ if args.dataset_file == 'vcoco':
47
+ self.invalid_ids = args.invalid_ids
48
+ self.valid_ids = np.concatenate((args.valid_ids,[-1]), axis=0) # no interaction
49
+ elif args.dataset_file == 'hico-det':
50
+ self.invalid_ids = []
51
+ self.valid_ids = list(range(num_actions)) + [-1]
52
+
53
+ # for targets
54
+ self.num_tgt_classes = len(args.valid_obj_ids)
55
+ tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
56
+ tgt_empty_weight[-1] = self.HOI_eos_coef
57
+ self.register_buffer('tgt_empty_weight', tgt_empty_weight)
58
+ self.dataset_file = args.dataset_file
59
+
60
+ empty_weight = torch.ones(self.num_classes + 1)
61
+ empty_weight[-1] = eos_coef
62
+ self.register_buffer('empty_weight', empty_weight)
63
+
64
+ #######################################################################################################################
65
+ # * DETR Losses
66
+ #######################################################################################################################
67
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
68
+ """Classification loss (NLL)
69
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
70
+ """
71
+ assert 'pred_logits' in outputs
72
+ src_logits = outputs['pred_logits']
73
+
74
+ idx = self._get_src_permutation_idx(indices)
75
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
76
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
77
+ target_classes[idx] = target_classes_o
78
+
79
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
80
+ losses = {'loss_ce': loss_ce}
81
+
82
+ if log:
83
+ # TODO this should probably be a separate loss, not hacked in this one here
84
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
85
+ return losses
86
+
87
+ @torch.no_grad()
88
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
89
+ """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
90
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
91
+ """
92
+ pred_logits = outputs['pred_logits']
93
+ device = pred_logits.device
94
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
95
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
96
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
97
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
98
+ losses = {'cardinality_error': card_err}
99
+ return losses
100
+
101
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
102
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
103
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
104
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
105
+ """
106
+ assert 'pred_boxes' in outputs
107
+ idx = self._get_src_permutation_idx(indices)
108
+ src_boxes = outputs['pred_boxes'][idx]
109
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
110
+
111
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
112
+
113
+ losses = {}
114
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
115
+
116
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
117
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
118
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
119
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
120
+ return losses
121
+
122
+
123
+ #######################################################################################################################
124
+ # * HOTR Losses
125
+ #######################################################################################################################
126
+ # >>> HOI Losses 1 : HO Pointer
127
+ def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes,use_consis, log=False):
128
+ assert ('pred_hidx' in outputs and 'pred_oidx' in outputs)
129
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
130
+ nu,q,hd=outputs['pred_hidx'].shape
131
+ src_hidx = outputs['pred_hidx'].view(self.num_path,nu//self.num_path,q,-1).transpose(0,1).flatten(0,1)
132
+ src_oidx = outputs['pred_oidx'].view(self.num_path,nu//self.num_path,q,-1).transpose(0,1).flatten(0,1)
133
+ hoi_ind=list(itertools.chain.from_iterable(hoi_indices))
134
+
135
+ idx = self._get_src_permutation_idx(hoi_ind)
136
+
137
+ target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
138
+ target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)
139
+
140
+ # H Pointer loss
141
+ target_classes_h = torch.cat([t["h_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_,J) in hoi_indice])
142
+ target_hidx_classes[idx] = target_classes_h
143
+
144
+ # O Pointer loss
145
+ target_classes_o = torch.cat([t["o_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_,J) in hoi_indice])
146
+ target_oidx_classes[idx] = target_classes_o
147
+
148
+ loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
149
+ loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)
150
+
151
+ #Consistency loss
152
+ if use_consis:
153
+ consistency_idxs=[self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices ]
154
+ src_hidx_inputs=[F.softmax(src_hidx.view(-1,self.num_path,q,hd)[i][consistency_idx[0]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
155
+ src_hidx_targets=[F.softmax(src_hidx.view(-1,self.num_path,q,hd)[i][consistency_idx[1]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
156
+ src_oidx_inputs=[F.softmax(src_oidx.view(-1,self.num_path,q,hd)[i][consistency_idx[0]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
157
+ src_oidx_targets=[F.softmax(src_oidx.view(-1,self.num_path,q,hd)[i][consistency_idx[1]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
158
+
159
+ loss_h_consistency=[0.5*(F.kl_div(src_hidx_input.log(),src_hidx_target.clone().detach(),reduction='batchmean')+F.kl_div(src_hidx_target.log(),src_hidx_input.clone().detach(),reduction='batchmean')) for src_hidx_input,src_hidx_target in zip(src_hidx_inputs,src_hidx_targets)]
160
+ loss_o_consistency=[0.5*(F.kl_div(src_oidx_input.log(),src_oidx_target.clone().detach(),reduction='batchmean')+F.kl_div(src_oidx_target.log(),src_oidx_input.clone().detach(),reduction='batchmean')) for src_oidx_input,src_oidx_target in zip(src_oidx_inputs,src_oidx_targets)]
161
+
162
+ loss_h_consistency=torch.mean(torch.stack(loss_h_consistency))
163
+ loss_o_consistency=torch.mean(torch.stack(loss_o_consistency))
164
+
165
+ losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o,'loss_h_consistency':loss_h_consistency,'loss_o_consistency':loss_o_consistency}
166
+ else:
167
+ losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}
168
+
169
+ return losses
170
+
171
+ # >>> HOI Losses 2 : pair actions
172
+ def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes,use_consis):
173
+ assert 'pred_actions' in outputs
174
+ src_actions = outputs['pred_actions'].flatten(end_dim=1)
175
+ hoi_ind=list(itertools.chain.from_iterable(hoi_indices))
176
+ # idx = self._get_src_permutation_idx(hoi_indices)
177
+ idx = self._get_src_permutation_idx(hoi_ind)
178
+
179
+ # Construct Target --------------------------------------------------------------------------------------------------------------
180
+ target_classes_o = torch.cat([t["pair_actions"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_,J) in hoi_indice])
181
+ target_classes = torch.full(src_actions.shape, 0, dtype=torch.float32, device=src_actions.device)
182
+ target_classes[..., -1] = 1 # the last index for no-interaction is '1' if a label exists
183
+
184
+ pos_classes = torch.full(target_classes[idx].shape, 0, dtype=torch.float32, device=src_actions.device) # else, the last index for no-interaction is '0'
185
+ pos_classes[:, :-1] = target_classes_o.float()
186
+ target_classes[idx] = pos_classes
187
+ # --------------------------------------------------------------------------------------------------------------------------------
188
+
189
+ # BCE Loss -----------------------------------------------------------------------------------------------------------------------
190
+ logits = src_actions.sigmoid()
191
+ loss_bce = F.binary_cross_entropy(logits[..., self.valid_ids], target_classes[..., self.valid_ids], reduction='none')
192
+ p_t = logits[..., self.valid_ids] * target_classes[..., self.valid_ids] + (1 - logits[..., self.valid_ids]) * (1 - target_classes[..., self.valid_ids])
193
+ loss_bce = ((1-p_t)**2 * loss_bce)
194
+ alpha_t = 0.25 * target_classes[..., self.valid_ids] + (1 - 0.25) * (1 - target_classes[..., self.valid_ids])
195
+ loss_focal = alpha_t * loss_bce
196
+ loss_act = loss_focal.sum() / max(target_classes[..., self.valid_ids[:-1]].sum(), 1)
197
+ # --------------------------------------------------------------------------------------------------------------------------------
198
+
199
+ #Consistency loss
200
+ if use_consis:
201
+ consistency_idxs=[self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
202
+ src_action_inputs=[F.logsigmoid(outputs['pred_actions'][i][consistency_idx[0]]) for i,consistency_idx in enumerate(consistency_idxs)]
203
+ src_action_targets=[F.logsigmoid(outputs['pred_actions'][i][consistency_idx[1]]) for i,consistency_idx in enumerate(consistency_idxs)]
204
+
205
+ loss_action_consistency=[F.mse_loss(src_action_input,src_action_target) for src_action_input,src_action_target in zip(src_action_inputs,src_action_targets)]
206
+ loss_action_consistency=torch.mean(torch.stack(loss_action_consistency))
207
+ # import pdb;pdb.set_trace()
208
+ losses = {'loss_act': loss_act,'loss_act_consistency':loss_action_consistency}
209
+ else:
210
+ losses = {'loss_act': loss_act}
211
+ return losses
212
+
213
+ # HOI Losses 3 : action targets
214
+ def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions,use_consis, log=True):
215
+ assert 'pred_obj_logits' in outputs
216
+ src_logits = outputs['pred_obj_logits']
217
+ nu,q,hd=outputs['pred_obj_logits'].shape
218
+ hoi_ind=list(itertools.chain.from_iterable(hoi_indices))
219
+ idx = self._get_src_permutation_idx(hoi_ind)
220
+
221
+ target_classes_o = torch.cat([t['pair_targets'][J] for t, hoi_indice in zip(targets, hoi_indices) for (_,J) in hoi_indice])
222
+ pad_tgt = -1 # src_logits.shape[2]-1
223
+ target_classes = torch.full(src_logits.shape[:2], pad_tgt, dtype=torch.int64, device=src_logits.device)
224
+ target_classes[idx] = target_classes_o
225
+
226
+ loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.tgt_empty_weight, ignore_index=-1)
227
+
228
+ #consistency
229
+ if use_consis:
230
+ consistency_idxs=[self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
231
+ src_logits_inputs=[F.softmax(src_logits.view(-1,self.num_path,q,hd)[i][consistency_idx[0]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
232
+ src_logits_targets=[F.softmax(src_logits.view(-1,self.num_path,q,hd)[i][consistency_idx[1]],-1) for i,consistency_idx in enumerate(consistency_idxs)]
233
+ loss_tgt_consistency=[0.5*(F.kl_div(src_logit_input.log(),src_logit_target.clone().detach(),reduction='batchmean')+F.kl_div(src_logit_target.log(),src_logit_input.clone().detach(),reduction='batchmean')) for src_logit_input,src_logit_target in zip(src_logits_inputs,src_logits_targets)]
234
+ loss_tgt_consistency=torch.mean(torch.stack(loss_tgt_consistency))
235
+ losses = {'loss_tgt': loss_obj_ce,"loss_tgt_label_consistency":loss_tgt_consistency}
236
+ else:
237
+ losses = {'loss_tgt': loss_obj_ce}
238
+ if log:
239
+ ignore_idx = (target_classes_o != -1)
240
+ losses['obj_class_error'] = 100 - accuracy(src_logits[idx][ignore_idx, :-1], target_classes_o[ignore_idx])[0]
241
+ # losses['obj_class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
242
+ return losses
243
+
244
+ def _get_src_permutation_idx(self, indices):
245
+ # permute predictions following indices
246
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
247
+ src_idx = torch.cat([src for (src, _) in indices])
248
+ return batch_idx, src_idx
249
+
250
+ def _get_consistency_src_permutation_idx(self, indices):
251
+ all_tgt=torch.cat([j for(_,j) in indices]).unique()
252
+ path_idxs=[torch.cat([torch.tensor([i]) for i,(_,t)in enumerate(indices) if (t==tgt).any()]) for tgt in all_tgt]
253
+ q_idxs=[torch.cat([s[t==tgt] for (s,t)in indices]) for tgt in all_tgt]
254
+ path_idxs=torch.cat([torch.combinations(path_idx) for path_idx in path_idxs if len(path_idx)>1])
255
+ q_idxs=torch.cat([torch.combinations(q_idx) for q_idx in q_idxs if len(q_idx)>1])
256
+
257
+ return (path_idxs[:,0],q_idxs[:,0]),(path_idxs[:,1],q_idxs[:,1])
258
+
259
+ def _get_tgt_permutation_idx(self, indices):
260
+ # permute targets following indices
261
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
262
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
263
+ return batch_idx, tgt_idx
264
+
265
+ # *****************************************************************************
266
+ # >>> DETR Losses
267
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
268
+ loss_map = {
269
+ 'labels': self.loss_labels,
270
+ 'cardinality': self.loss_cardinality,
271
+ 'boxes': self.loss_boxes
272
+ }
273
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
274
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
275
+
276
+ # >>> HOTR Losses
277
+ def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes,use_consis, **kwargs):
278
+ loss_map = {
279
+ 'pair_labels': self.loss_pair_labels,
280
+ 'pair_actions': self.loss_pair_actions
281
+ }
282
+ if self.dataset_file == 'hico-det': loss_map['pair_targets'] = self.loss_pair_targets
283
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
284
+ return loss_map[loss](outputs, targets, indices, num_boxes,use_consis, **kwargs)
285
+ # *****************************************************************************
286
+
287
+ def forward(self, outputs, targets, log=False):
288
+ """ This performs the loss computation.
289
+ Parameters:
290
+ outputs: dict of tensors, see the output specification of the model for the format
291
+ targets: list of dicts, such that len(targets) == batch_size.
292
+ The expected keys in each dict depends on the losses applied, see each loss' doc
293
+ """
294
+ outputs_without_aux = {k: v for k, v in outputs.items() if (k != 'aux_outputs' and k != 'hoi_aux_outputs')}
295
+
296
+ # Retrieve the matching between the outputs of the last layer and the targets
297
+ indices = self.matcher(outputs_without_aux, targets)
298
+
299
+ if self.HOI_losses is not None:
300
+ input_targets = [copy.deepcopy(target) for target in targets]
301
+ hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)
302
+
303
+ # Compute the average number of target boxes across all nodes, for normalization purposes
304
+ num_boxes = sum(len(t["labels"]) for t in targets)
305
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
306
+ if is_dist_avail_and_initialized():
307
+ torch.distributed.all_reduce(num_boxes)
308
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
309
+
310
+ # Compute all the requested losses
311
+ losses = {}
312
+ for loss in self.losses:
313
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
314
+
315
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
316
+ if 'aux_outputs' in outputs:
317
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
318
+ indices = self.matcher(aux_outputs, targets)
319
+ for loss in self.losses:
320
+ if loss == 'masks':
321
+ # Intermediate masks losses are too costly to compute, we ignore them.
322
+ continue
323
+ kwargs = {}
324
+ if loss == 'labels':
325
+ # Logging is enabled only for the last layer
326
+ kwargs = {'log': False}
327
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
328
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
329
+ losses.update(l_dict)
330
+
331
+ # HOI detection losses
332
+ if self.HOI_losses is not None:
333
+ for loss in self.HOI_losses:
334
+ losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices, num_boxes,self.use_consis))
335
+ # if self.dataset_file == 'hico-det': losses['loss_oidx'] += losses['loss_tgt']
336
+
337
+ if 'hoi_aux_outputs' in outputs:
338
+ for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
339
+ input_targets = [copy.deepcopy(target) for target in targets]
340
+ hoi_indices, targets_for_aux = self.HOI_matcher(aux_outputs, input_targets, indices, log)
341
+ for loss in self.HOI_losses:
342
+ kwargs = {}
343
+ if loss == 'pair_targets': kwargs = {'log': False} # Logging is enabled only for the last layer
344
+ l_dict = self.get_HOI_loss(loss, aux_outputs, hoi_targets, hoi_indices, num_boxes,self.use_consis, **kwargs)
345
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
346
+ losses.update(l_dict)
347
+ # if self.dataset_file == 'hico-det': losses[f'loss_oidx_{i}'] += losses[f'loss_tgt_{i}']
348
+
349
+ return losses
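
The Hungarian `indices` and the permutation helper above in miniature: one
(pred_idx, tgt_idx) pair per image, flattened into (batch, query) coordinates
for advanced indexing into [batch, queries, ...] tensors:

    import torch

    indices = [(torch.tensor([2, 0]), torch.tensor([0, 1])),   # image 0
               (torch.tensor([1]),    torch.tensor([0]))]      # image 1
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx   = torch.cat([src for (src, _) in indices])
    print(batch_idx.tolist(), src_idx.tolist())  # [0, 0, 1] [2, 0, 1]

    pred_logits = torch.randn(2, 3, 5)               # [batch=2, queries=3, classes=5]
    print(pred_logits[(batch_idx, src_idx)].shape)   # torch.Size([3, 5])
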
hotr/models/detr.py ADDED
@@ -0,0 +1,187 @@
+ # ------------------------------------------------------------------------
+ # HOTR official code : hotr/models/detr.py
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+ # Modified from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+ """
+ DETR & HOTR model and criterion classes.
+ """
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from hotr.util.misc import (NestedTensor, nested_tensor_from_tensor_list)
+
+ from .backbone import build_backbone
+ from .detr_matcher import build_matcher
+ from .hotr_matcher import build_hoi_matcher
+ from .transformer import build_transformer, build_hoi_transformer
+ from .criterion import SetCriterion
+ from .post_process import PostProcess
+ from .feed_forward import MLP
+
+ from .hotr import HOTR
+ from .hotr_v1 import HOTR_V1
+
+ class DETR(nn.Module):
+     """ This is the DETR module that performs object detection """
+     def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
+         """ Initializes the model.
+         Parameters:
+             backbone: torch module of the backbone to be used. See backbone.py
+             transformer: torch module of the transformer architecture. See transformer.py
+             num_classes: number of object classes
+             num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
+                          DETR can detect in a single image. For COCO, we recommend 100 queries.
+             aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+         """
+         super().__init__()
+         self.num_queries = num_queries
+         self.transformer = transformer
+         hidden_dim = transformer.d_model
+         self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+         self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+         self.query_embed = nn.Embedding(num_queries, hidden_dim)
+         self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
+         self.backbone = backbone
+         self.aux_loss = aux_loss
+
+     def forward(self, samples: NestedTensor):
+         """ The forward expects a NestedTensor, which consists of:
+                - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+                - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+            It returns a dict with the following elements:
+                - "pred_logits": the classification logits (including no-object) for all queries.
+                                 Shape = [batch_size x num_queries x (num_classes + 1)]
+                - "pred_boxes": the normalized box coordinates for all queries, represented as
+                                (center_x, center_y, height, width). These values are normalized in [0, 1],
+                                relative to the size of each individual image (disregarding possible padding).
+                                See PostProcess for information on how to retrieve the unnormalized bounding box.
+                - "aux_outputs": optional, only returned when auxiliary losses are activated. It is a list of
+                                 dictionaries containing the two above keys for each decoder layer.
+         """
+         if isinstance(samples, (list, torch.Tensor)):
+             samples = nested_tensor_from_tensor_list(samples)
+         features, pos = self.backbone(samples)
+
+         src, mask = features[-1].decompose()
+         assert mask is not None
+         hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
+
+         outputs_class = self.class_embed(hs)
+         outputs_coord = self.bbox_embed(hs).sigmoid()
+         out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
+         if self.aux_loss:
+             out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
+
+         return out
+
+     @torch.jit.unused
+     def _set_aux_loss(self, outputs_class, outputs_coord):
+         # this is a workaround to make torchscript happy, as torchscript
+         # doesn't support dictionaries with non-homogeneous values, such
+         # as a dict having both a Tensor and a list.
+         return [{'pred_logits': a, 'pred_boxes': b}
+                 for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+ def build(args):
+     device = torch.device(args.device)
+
+     backbone = build_backbone(args)
+
+     transformer = build_transformer(args)
+
+     model = DETR(
+         backbone,
+         transformer,
+         num_classes=args.num_classes,
+         num_queries=args.num_queries,
+         aux_loss=args.aux_loss,
+     )
+
+     matcher = build_matcher(args)
+     weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
+     weight_dict['loss_giou'] = args.giou_loss_coef
+
+     # TODO this is a hack
+     if args.aux_loss:
+         aux_weight_dict = {}
+         for i in range(args.dec_layers - 1):
+             aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+         weight_dict.update(aux_weight_dict)
+
+     losses = ['labels', 'boxes', 'cardinality'] if args.frozen_weights is None else []
+     if args.HOIDet:
+         hoi_matcher = build_hoi_matcher(args)
+         hoi_losses = []
+         hoi_losses.append('pair_labels')
+         hoi_losses.append('pair_actions')
+         if args.dataset_file == 'hico-det': hoi_losses.append('pair_targets')
+
+         hoi_weight_dict = {}
+         hoi_weight_dict['loss_hidx'] = args.hoi_idx_loss_coef
+         hoi_weight_dict['loss_oidx'] = args.hoi_idx_loss_coef
+         hoi_weight_dict['loss_h_consistency'] = args.hoi_idx_consistency_loss_coef
+         hoi_weight_dict['loss_o_consistency'] = args.hoi_idx_consistency_loss_coef
+         hoi_weight_dict['loss_act'] = args.hoi_act_loss_coef
+         hoi_weight_dict['loss_act_consistency'] = args.hoi_act_consistency_loss_coef
+         if args.dataset_file == 'hico-det':
+             hoi_weight_dict['loss_tgt'] = args.hoi_tgt_loss_coef
+             hoi_weight_dict['loss_tgt_consistency'] = args.hoi_tgt_consistency_loss_coef
+         if args.hoi_aux_loss:
+             hoi_aux_weight_dict = {}
+             for i in range(args.hoi_dec_layers):
+                 hoi_aux_weight_dict.update({k + f'_{i}': v for k, v in hoi_weight_dict.items()})
+             hoi_weight_dict.update(hoi_aux_weight_dict)
+
+         criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=hoi_weight_dict,
+                                  eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions,
+                                  HOI_losses=hoi_losses, HOI_matcher=hoi_matcher, args=args)
+
+         interaction_transformer = build_hoi_transformer(args)  # if (args.share_enc and args.pretrained_dec) else None
+
+         kwargs = {}
+         if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids
+         if args.sep_enc_forward:
+             model = HOTR_V1(
+                 detr=model,
+                 num_hoi_queries=args.num_hoi_queries,
+                 num_actions=args.num_actions,
+                 interaction_transformer=interaction_transformer,
+                 augpath_name=args.augpath_name,
+                 share_dec_param=args.share_dec_param,
+                 stop_grad_stage=args.stop_grad_stage,
+                 freeze_detr=(args.frozen_weights is not None),
+                 share_enc=args.share_enc,
+                 pretrained_dec=args.pretrained_dec,
+                 temperature=args.temperature,
+                 hoi_aux_loss=args.hoi_aux_loss,
+                 **kwargs  # only return the object class for the HICO-DET dataset
+             )
+         else:
+             model = HOTR(
+                 detr=model,
+                 num_hoi_queries=args.num_hoi_queries,
+                 num_actions=args.num_actions,
+                 interaction_transformer=interaction_transformer,
+                 augpath_name=args.augpath_name,
+                 share_dec_param=args.share_dec_param,
+                 stop_grad_stage=args.stop_grad_stage,
+                 freeze_detr=(args.frozen_weights is not None),
+                 share_enc=args.share_enc,
+                 pretrained_dec=args.pretrained_dec,
+                 temperature=args.temperature,
+                 hoi_aux_loss=args.hoi_aux_loss,
+                 **kwargs  # only return the object class for the HICO-DET dataset
+             )
+         postprocessors = {'hoi': PostProcess(args.HOIDet)}
+     else:
+         criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
+                                  eos_coef=args.eos_coef, losses=losses)
+         postprocessors = {'bbox': PostProcess(args.HOIDet)}
+     criterion.to(device)
+
+     return model, criterion, postprocessors
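The auxiliary-loss bookkeeping in `build()` above simply clones the base weight dict once per decoder layer, suffixing every key with `_{i}` so the criterion can weight each layer's loss independently. A minimal, self-contained sketch (the coefficient values and layer count here are illustrative, not the repo's defaults):

```python
# Hedged sketch of the aux-loss weight expansion performed in build().
base = {'loss_hidx': 1.0, 'loss_oidx': 1.0, 'loss_act': 1.0}  # illustrative coefficients
hoi_dec_layers = 3  # illustrative layer count

expanded = dict(base)
for i in range(hoi_dec_layers):
    # one suffixed copy of every loss key per decoder layer
    expanded.update({k + f'_{i}': v for k, v in base.items()})

# expanded now contains 'loss_act', 'loss_act_0', 'loss_act_1', 'loss_act_2', ...
print(sorted(expanded))
```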
hotr/models/detr_matcher.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Modules to compute the matching cost and solve the corresponding LSAP.
+ """
+ import torch
+ from scipy.optimize import linear_sum_assignment
+ from torch import nn
+
+ from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+ class HungarianMatcher(nn.Module):
+     """This class computes an assignment between the targets and the predictions of the network.
+     For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+     there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+     while the others are un-matched (and thus treated as non-objects).
+     """
+
+     def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
+         """Creates the matcher
+         Params:
+             cost_class: This is the relative weight of the classification error in the matching cost
+             cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+             cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+         """
+         super().__init__()
+         self.cost_class = cost_class
+         self.cost_bbox = cost_bbox
+         self.cost_giou = cost_giou
+         assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+     @torch.no_grad()
+     def forward(self, outputs, targets):
+         """ Performs the matching
+         Params:
+             outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+             targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+         Returns:
+             A list of size batch_size, containing tuples of (index_i, index_j) where:
+                 - index_i is the indices of the selected predictions (in order)
+                 - index_j is the indices of the corresponding selected targets (in order)
+             For each batch element, it holds:
+                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+         """
+         bs, num_queries = outputs["pred_logits"].shape[:2]
+
+         # We flatten to compute the cost matrices in a batch
+         out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+         out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+         # Also concat the target labels and boxes
+         tgt_ids = torch.cat([v["labels"] for v in targets])
+         tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+         # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+         # but approximate it with 1 - proba[target class].
+         # The 1 is a constant that doesn't change the matching, so it can be omitted.
+         cost_class = -out_prob[:, tgt_ids]
+
+         # Compute the L1 cost between boxes
+         cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+         # Compute the giou cost between boxes
+         cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+         # Final cost matrix
+         C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+         C = C.view(bs, num_queries, -1).cpu()
+
+         sizes = [len(v["boxes"]) for v in targets]
+         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+         return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+ def build_matcher(args):
+     return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
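A toy usage sketch for the matcher (shapes and values are made up; the import path assumes this file lives at `hotr/models/detr_matcher.py` as shown above):

```python
import torch
from hotr.models.detr_matcher import HungarianMatcher  # assumed import path

matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)

# 2 images, 4 queries, 5 classes (softmaxed over the class dimension inside forward)
outputs = {
    'pred_logits': torch.randn(2, 4, 5),
    'pred_boxes':  torch.rand(2, 4, 4),   # normalized (cx, cy, w, h)
}
targets = [
    {'labels': torch.tensor([1]),    'boxes': torch.rand(1, 4)},
    {'labels': torch.tensor([0, 2]), 'boxes': torch.rand(2, 4)},
]

indices = matcher(outputs, targets)
# indices[b] = (pred_idx, tgt_idx); len(pred_idx) equals the number of GT boxes in image b
```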
hotr/models/feed_forward.py ADDED
@@ -0,0 +1,16 @@
+ import torch.nn.functional as F
+ from torch import nn
+
+ class MLP(nn.Module):
+     """ Very simple multi-layer perceptron (also called FFN)"""
+
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
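For instance, DETR's 3-layer box head above is an `MLP(hidden_dim, hidden_dim, 4, 3)`; a hedged sketch of its use (the sizes are illustrative, and `MLP` is the class defined above):

```python
import torch

head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
x = torch.randn(2, 100, 256)   # (batch, queries, hidden_dim)
boxes = head(x).sigmoid()      # (2, 100, 4): normalized (cx, cy, w, h) per query
```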
hotr/models/hotr.py ADDED
@@ -0,0 +1,241 @@
+ # ------------------------------------------------------------------------
+ # HOTR official code : hotr/models/hotr.py
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import copy
+ import time
+ import datetime
+
+ from hotr.util.misc import NestedTensor, nested_tensor_from_tensor_list
+ from .feed_forward import MLP
+
+ class HOTR(nn.Module):
+     def __init__(self, detr,
+                  num_hoi_queries,
+                  num_actions,
+                  interaction_transformer,
+                  augpath_name,
+                  share_dec_param,
+                  stop_grad_stage,
+                  freeze_detr,
+                  share_enc,
+                  pretrained_dec,
+                  temperature,
+                  hoi_aux_loss,
+                  return_obj_class=None):
+         super().__init__()
+
+         # * Instance Transformer ---------------
+         self.detr = detr
+         if freeze_detr:
+             # if this flag is given, freeze the object-detection-related parameters of DETR
+             for p in self.parameters():
+                 p.requires_grad_(False)
+         hidden_dim = detr.transformer.d_model
+         # --------------------------------------
+
+         # * Interaction Transformer -----------------------------------------
+         self.num_queries = num_hoi_queries
+         self.query_embed = nn.Embedding(self.num_queries, hidden_dim)
+         self.H_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+         self.O_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+         self.action_embed = nn.Linear(hidden_dim, num_actions + 1)
+         # --------------------------------------------------------------------
+
+         # * HICO-DET FFN heads ---------------------------------------------
+         self.return_obj_class = (return_obj_class is not None)
+         if return_obj_class: self._valid_obj_ids = return_obj_class + [return_obj_class[-1] + 1]
+         # ------------------------------------------------------------------
+
+         # * Transformer Options ---------------------------------------------
+         self.interaction_transformer = interaction_transformer
+
+         if share_enc:  # share encoder
+             self.interaction_transformer.encoder = detr.transformer.encoder
+
+         if pretrained_dec:  # free variables for interaction decoder
+             self.interaction_transformer.decoder = copy.deepcopy(detr.transformer.decoder)
+             for p in self.interaction_transformer.decoder.parameters():
+                 p.requires_grad_(True)
+         # ---------------------------------------------------------------------
+
+         # * Augmented paths ---------------------------------------------------
+         self.aug_paths = augpath_name
+
+         if 'p2' in augpath_name:
+             if not share_dec_param:
+                 self.xtoHO_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+                 self.HOtoI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+             else:
+                 self.xtoHO_interaction_decoder = self.interaction_transformer.decoder
+                 self.HOtoI_interaction_decoder = self.interaction_transformer.decoder
+
+             self.query_embed_HOtoI = nn.Embedding(self.num_queries, hidden_dim)
+             self.query_embed_HOtoI2 = nn.Embedding(self.num_queries, hidden_dim)
+             self.H_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.O_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.action_embed_HOtoI = nn.Linear(hidden_dim, num_actions + 1)
+
+         if 'p3' in augpath_name:
+             if not share_dec_param:
+                 self.xtoHI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+                 self.HItoO_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+             else:
+                 self.xtoHI_interaction_decoder = self.interaction_transformer.decoder
+                 self.HItoO_interaction_decoder = self.interaction_transformer.decoder
+
+             self.query_embed_HItoO = nn.Embedding(self.num_queries, hidden_dim)
+             self.query_embed_HItoO2 = nn.Embedding(self.num_queries, hidden_dim)
+             self.H_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.O_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.action_embed_HItoO = nn.Linear(hidden_dim, num_actions + 1)
+
+         if 'p4' in augpath_name:
+             if not share_dec_param:
+                 self.xtoOI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+                 self.OItoH_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
+             else:
+                 self.xtoOI_interaction_decoder = self.interaction_transformer.decoder
+                 self.OItoH_interaction_decoder = self.interaction_transformer.decoder
+
+             self.query_embed_OItoH = nn.Embedding(self.num_queries, hidden_dim)
+             self.query_embed_OItoH2 = nn.Embedding(self.num_queries, hidden_dim)
+             self.H_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.O_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
+             self.action_embed_OItoH = nn.Linear(hidden_dim, num_actions + 1)
+
+         self.stop_grad_stage = stop_grad_stage
+
+         # * Loss Options -------------------
+         self.tau = temperature
+         self.hoi_aux_loss = hoi_aux_loss
+         # ----------------------------------
+
+     def forward(self, samples: NestedTensor):
+         if isinstance(samples, (list, torch.Tensor)):
+             samples = nested_tensor_from_tensor_list(samples)
+
+         # >>>>>>>>>>>> BACKBONE LAYERS <<<<<<<<<<<<<<<
+         features, pos = self.detr.backbone(samples)
+         bs = features[-1].tensors.shape[0]
+         src, mask = features[-1].decompose()
+         assert mask is not None
+         # ----------------------------------------------
+
+         # >>>>>>>>>>>> OBJECT DETECTION LAYERS <<<<<<<<<<
+         start_time = time.time()
+         hs, memory = self.detr.transformer(self.detr.input_proj(src), mask, self.detr.query_embed.weight, pos[-1])
+         inst_repr = F.normalize(hs[-1], p=2, dim=2)  # instance representations
+
+         # Prediction Heads for Object Detection
+         outputs_class = self.detr.class_embed(hs)
+         outputs_coord = self.detr.bbox_embed(hs).sigmoid()
+         object_detection_time = time.time() - start_time
+         # -----------------------------------------------
+
+         # >>>>>>>>>>>> HOI DETECTION LAYERS <<<<<<<<<<<<<<<
+         start_time = time.time()
+         assert hasattr(self, 'interaction_transformer'), "Missing Interaction Transformer."
+         H_Pointer_reprs_bag, O_Pointer_reprs_bag, outputs_action = [], [], []
+         # main path P1
+         interaction_hs = self.interaction_transformer(self.detr.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]  # interaction representations
+         H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed(interaction_hs), p=2, dim=-1))
+         O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed(interaction_hs), p=2, dim=-1))
+         outputs_action.append(self.action_embed(interaction_hs))
+
+         if len(self.aug_paths) != 0:
+             pos_aug = pos[-1].flatten(2).permute(2, 0, 1)
+             mask_aug = mask.flatten(1)
+
+             # P2 (x->HO->I)
+             if 'p2' in self.aug_paths:
+                 tgt_2 = torch.zeros_like(self.query_embed_HOtoI.weight.unsqueeze(1).repeat(1, bs, 1))
+                 hs_HOtoI = self.xtoHO_interaction_decoder(tgt_2, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HOtoI.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
+                 tgt_HOtoI = hs_HOtoI.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_HOtoI.clone().detach().transpose(1, 2)[-1]
+                 hs2_HOtoI = self.HOtoI_interaction_decoder(tgt_HOtoI, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HOtoI2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
+                 H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
+                 O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
+                 outputs_action.append(self.action_embed_HOtoI(hs2_HOtoI))
+             # P3 (x->HI->O)
+             if 'p3' in self.aug_paths:
+                 tgt_3 = torch.zeros_like(self.query_embed_HItoO.weight.unsqueeze(1).repeat(1, bs, 1))
+                 hs_HItoO = self.xtoHI_interaction_decoder(tgt_3, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HItoO.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
+                 tgt_HItoO = hs_HItoO.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_HItoO.clone().detach().transpose(1, 2)[-1]
+                 hs2_HItoO = self.HItoO_interaction_decoder(tgt_HItoO, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HItoO2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
+                 H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HItoO(hs_HItoO), p=2, dim=-1))
+                 O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HItoO(hs2_HItoO), p=2, dim=-1))
+                 outputs_action.append(self.action_embed_HItoO(hs_HItoO))
+             # P4 (x->OI->H)
+             if 'p4' in self.aug_paths:
+                 tgt_4 = torch.zeros_like(self.query_embed_OItoH.weight.unsqueeze(1).repeat(1, bs, 1))
+                 hs_OItoH = self.xtoOI_interaction_decoder(tgt_4, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_OItoH.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)  # fixed: previously passed tgt_3 by mistake
+                 tgt_OItoH = hs_OItoH.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_OItoH.clone().detach().transpose(1, 2)[-1]
+                 hs2_OItoH = self.OItoH_interaction_decoder(tgt_OItoH, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_OItoH2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
+                 H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_OItoH(hs2_OItoH), p=2, dim=-1))
+                 O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_OItoH(hs_OItoH), p=2, dim=-1))
+                 outputs_action.append(self.action_embed_OItoH(hs_OItoH))
+
+         inst_repr_all = inst_repr.transpose(1, 2).repeat(1 + len(self.aug_paths), 1, 1)
+
+         H_Pointer_reprs_bag = torch.cat(H_Pointer_reprs_bag, 1)
+         O_Pointer_reprs_bag = torch.cat(O_Pointer_reprs_bag, 1)
+
+         outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr_all)) / self.tau for H_Pointer_repr in H_Pointer_reprs_bag]  # (dec_layer, (1+len(aug))*bs, dec_q, hidden_dim)
+         outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr_all)) / self.tau for O_Pointer_repr in O_Pointer_reprs_bag]
+
+         outputs_action = torch.stack(outputs_action, dim=2)  # (dec_layer, bs, 1+#aug, dec_q, #action)
+
+         # --------------------------------------------------
+         hoi_detection_time = time.time() - start_time
+         hoi_recognition_time = max(hoi_detection_time - object_detection_time, 0)
+         # -------------------------------------------------------------------
+
+         # [Target Classification]
+         if self.return_obj_class:
+             detr_logits = outputs_class[-1, ..., self._valid_obj_ids]
+             o_indices = [output_oidx.max(-1)[-1].view(1 + len(self.aug_paths), bs, self.num_queries).transpose(0, 1) for output_oidx in outputs_oidx]
+             obj_logit_stack = [torch.stack([detr_logits[batch_, o_idx, :] for batch_, o_idc in enumerate(o_indice) for o_idx in o_idc], 0) for o_indice in o_indices]
+             outputs_obj_class = obj_logit_stack
+
+         out = {
+             "pred_logits": outputs_class[-1],
+             "pred_boxes": outputs_coord[-1],
+             "pred_hidx": outputs_hidx[-1],
+             "pred_oidx": outputs_oidx[-1],
+             "pred_actions": outputs_action[-1],
+             "hoi_recognition_time": hoi_recognition_time,
+         }
+
+         if self.return_obj_class: out["pred_obj_logits"] = outputs_obj_class[-1]
+
+         if self.hoi_aux_loss:  # auxiliary loss
+             out['hoi_aux_outputs'] = \
+                 self._set_aux_loss_with_tgt(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_obj_class) \
+                 if self.return_obj_class else \
+                 self._set_aux_loss(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action)
+
+         return out
+
+     @torch.jit.unused
+     def _set_aux_loss(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action):
+         return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e}
+                 for a, b, c, d, e in zip(
+                     outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
+                     outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
+                     outputs_hidx[:-1],
+                     outputs_oidx[:-1],
+                     outputs_action[:-1])]
+
+     @torch.jit.unused
+     def _set_aux_loss_with_tgt(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_tgt):
+         return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e, 'pred_obj_logits': f}
+                 for a, b, c, d, e, f in zip(
+                     outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
+                     outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
+                     outputs_hidx[:-1],
+                     outputs_oidx[:-1],
+                     outputs_action[:-1],
+                     outputs_tgt[:-1])]
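The pointer heads above never regress boxes directly: each interaction query emits L2-normalized H/O pointer vectors, and `torch.bmm(pointer, inst_repr) / tau` turns cosine similarity against the instance queries into logits whose argmax selects the pointed-to detection. A standalone sketch of that similarity step (all sizes illustrative):

```python
import torch
import torch.nn.functional as F

tau = 0.05                                                   # temperature
inst_repr = F.normalize(torch.randn(1, 8, 16), p=2, dim=2)   # 8 instance queries, dim 16
h_pointer = F.normalize(torch.randn(1, 4, 16), p=2, dim=-1)  # 4 interaction queries

logits = torch.bmm(h_pointer, inst_repr.transpose(1, 2)) / tau  # (1, 4, 8)
h_indices = logits.softmax(-1).max(-1)[1]  # instance query each pointer selects
```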
hotr/models/hotr_matcher.py ADDED
@@ -0,0 +1,216 @@
+ # ------------------------------------------------------------------------
+ # HOTR official code : hotr/models/hotr_matcher.py
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+ import torch
+ from scipy.optimize import linear_sum_assignment
+ from torch import nn
+
+ from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+ import hotr.util.misc as utils
+ import wandb
+
+ class HungarianPairMatcher(nn.Module):
+     def __init__(self, args):
+         """Creates the matcher
+         Params:
+             cost_action: This is the relative weight of the multi-label action classification error in the matching cost
+             cost_hbox: This is the relative weight of the classification error for the human idx in the matching cost
+             cost_obox: This is the relative weight of the classification error for the object idx in the matching cost
+         """
+         super().__init__()
+         self.cost_action = args.set_cost_act
+         self.cost_hbox = self.cost_obox = args.set_cost_idx
+         self.cost_target = args.set_cost_tgt
+         self.log_printer = args.wandb
+         self.is_vcoco = (args.dataset_file == 'vcoco')
+         self.is_hico = (args.dataset_file == 'hico-det')
+         if self.is_vcoco:
+             self.valid_ids = args.valid_ids
+             self.invalid_ids = args.invalid_ids
+         assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs can't be 0"
+
+     def reduce_redundant_gt_box(self, tgt_bbox, indices):
+         """Filters redundant Ground-Truth Bounding Boxes
+         Due to random crop augmentation, there are cases where multiple redundant labels
+         exist for the exact same bounding box and object class.
+         This function deals with the redundant labels for smoother HOTR training.
+         """
+         tgt_bbox_unique, map_idx, idx_cnt = torch.unique(tgt_bbox, dim=0, return_inverse=True, return_counts=True)
+
+         k_idx, bbox_idx = indices
+
+         if len(tgt_bbox) != len(tgt_bbox_unique):
+             map_dict = {k: v for k, v in enumerate(map_idx)}
+             map_bbox2kidx = {int(bbox_id): k_id for bbox_id, k_id in zip(bbox_idx, k_idx)}
+
+             bbox_lst, k_lst = [], []
+             for bbox_id in bbox_idx:
+                 if map_dict[int(bbox_id)] not in bbox_lst:
+                     bbox_lst.append(map_dict[int(bbox_id)])
+                     k_lst.append(map_bbox2kidx[int(bbox_id)])
+             bbox_idx = torch.tensor(bbox_lst)
+             k_idx = torch.tensor(k_lst)
+             tgt_bbox_res = tgt_bbox_unique
+         else:
+             tgt_bbox_res = tgt_bbox
+
+         bbox_idx = bbox_idx.to(tgt_bbox.device)
+
+         return tgt_bbox_res, k_idx, bbox_idx
+
+     @torch.no_grad()
+     def forward(self, outputs, targets, indices, log=False):
+         assert "pred_actions" in outputs, "There is no action output for pair matching"
+         num_obj_queries = outputs["pred_boxes"].shape[1]
+         bs, num_path, num_queries = outputs["pred_actions"].shape[:3]
+         detr_query_num = outputs["pred_logits"].shape[1] \
+             if (outputs["pred_oidx"].shape[-1] == (outputs["pred_logits"].shape[1] + 1)) else -1
+
+         return_list = []
+         if self.log_printer and log:
+             log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
+             if self.is_hico: log_dict['tgt_cost'] = []
+
+         for batch_idx in range(bs):
+             tgt_bbox = targets[batch_idx]["boxes"]  # (num_boxes, 4)
+             tgt_cls = targets[batch_idx]["labels"]  # (num_boxes)
+
+             if self.is_vcoco:
+                 targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0
+                 keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0)
+                 targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx]
+                 targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx]
+                 targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx]
+
+                 tgt_pbox = targets[batch_idx]["pair_boxes"]   # (num_pair_boxes, 8)
+                 tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 29)
+                 tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)
+
+                 tgt_hbox = tgt_pbox[:, :4]  # (num_pair_boxes, 4)
+                 tgt_obox = tgt_pbox[:, 4:]  # (num_pair_boxes, 4)
+             elif self.is_hico:
+                 tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 117)
+                 tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)
+
+                 tgt_hbox = targets[batch_idx]["sub_boxes"]  # (num_pair_boxes, 4)
+                 tgt_obox = targets[batch_idx]["obj_boxes"]  # (num_pair_boxes, 4)
+
+             # find which gt boxes match the h, o boxes in the pair
+             if self.is_vcoco:
+                 hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
+             elif self.is_hico:
+                 hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
+             obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
+             obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1  # turn the class of occluded objects to -1
+
+             bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
+             bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
+             bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)
+
+             cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1)
+             cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1)
+
+             # find which gt boxes matches which prediction in K
+             h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False)  # (num_hbox, num_boxes)
+             o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False)  # (num_obox, num_boxes)
+
+             tgt_hids, tgt_oids = [], []
+
+             # obtain ground truth indices for h
+             assert len(h_match_indices) == len(o_match_indices), "human/object GT matches are misaligned"
+
+             for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices):
+                 hbox_idx, H_bbox_idx = h_match_idx
+                 obox_idx, O_bbox_idx = o_match_idx
+                 if O_bbox_idx == (len(bbox_with_cls) - 1):  # if the object class is -1
+                     O_bbox_idx = H_bbox_idx  # happens in V-COCO, where the target object may not appear
+
+                 GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
+                 query_idx_for_H = k_idx[GT_idx_for_H]
+                 tgt_hids.append(query_idx_for_H)
+
+                 GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
+                 query_idx_for_O = k_idx[GT_idx_for_O]
+                 tgt_oids.append(query_idx_for_O)
+
+             # check if empty
+             if len(tgt_hids) == 0: tgt_hids.append(torch.as_tensor([-1]))  # we later ignore the label -1
+             if len(tgt_oids) == 0: tgt_oids.append(torch.as_tensor([-1]))  # we later ignore the label -1
+
+             tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0)
+             if tgt_act.shape[0] == 0:
+                 tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device)
+                 targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device)
+                 if self.is_hico:
+                     pad_tgt = -1  # outputs["pred_obj_logits"].shape[-1]-1
+                     tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"])
+                     targets[batch_idx]["pair_targets"] = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device)
+                 tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)
+
+             # Concat target label
+             tgt_hids = torch.cat(tgt_hids).repeat(num_path)
+             tgt_oids = torch.cat(tgt_oids).repeat(num_path)
+
+             outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
+             outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
+
+             outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
+             out_hprob = outputs_hidx[batch_idx].softmax(-1)
+             out_oprob = outputs_oidx[batch_idx].softmax(-1)
+             out_act = outputs_action[batch_idx].clone()
+             if self.is_vcoco: out_act[..., self.invalid_ids] = 0
+             if self.is_hico:
+                 outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path, num_queries, -1).view(bs, num_path * num_queries, -1)
+                 out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
+                 out_tgt[..., -1] = 0  # don't get cost for no-object
+
+             tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1).repeat(num_path, 1)
+
+             cost_hclass = -out_hprob[:, tgt_hids]  # [batch_size * num_queries, detr.num_queries+1]
+             cost_oclass = -out_oprob[:, tgt_oids]  # [batch_size * num_queries, detr.num_queries+1]
+
+             cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
+             cost_neg_act = (torch.matmul(out_act, (~tgt_act.bool()).type(torch.int64).t().float())) / (~tgt_act.bool()).type(torch.int64).sum(dim=-1).unsqueeze(0)
+             cost_action = cost_pos_act + cost_neg_act
+
+             h_cost = self.cost_hbox * cost_hclass
+             o_cost = self.cost_obox * cost_oclass
+             act_cost = self.cost_action * cost_action
+
+             C = h_cost + o_cost + act_cost
+
+             if self.is_hico:
+                 cost_target = -out_tgt[:, tgt_tgt.repeat(num_path)]
+                 tgt_cost = self.cost_target * cost_target
+                 C += tgt_cost
+             C = C.view(num_path, num_queries, -1).cpu()
+
+             sizes = [len(tgt_hids) // num_path] * num_path
+             hoi_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+             return_list.append([(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in hoi_indices])
+
+             targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device)
+             targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device)
+             log_act_cost = torch.zeros([1]).to(tgt_act.device) if tgt_act.shape[0] == 0 else act_cost.min(dim=0)[0].mean()
+
+             if self.log_printer and log:
+                 log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
+                 log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
+                 log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
+                 if self.is_hico: log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())
+         if self.log_printer and log:
+             log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean()
+             log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean()
+             log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean()
+             if self.is_hico: log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean()
+             if utils.get_rank() == 0: wandb.log(log_dict)
+         return return_list, targets
+
+ def build_hoi_matcher(args):
+     return HungarianPairMatcher(args)
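The multi-label action cost above rewards predicted scores on annotated actions and penalizes scores on the remaining ones, each normalized by its label count. A toy replica of the `cost_pos_act`/`cost_neg_act` computation (shapes are illustrative):

```python
import torch

out_act = torch.rand(3, 5)                      # 3 predictions, 5 action classes
tgt_act = torch.tensor([[1., 0., 1., 0., 0.],
                        [0., 0., 0., 1., 0.]])  # 2 GT pairs, multi-hot labels

pos = -(out_act @ tgt_act.t()) / tgt_act.sum(-1)             # reward matched actions
neg = (out_act @ (1 - tgt_act).t()) / (1 - tgt_act).sum(-1)  # penalize unmatched ones
cost_action = pos + neg                         # (3, 2): prediction x GT pair
```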
hotr/models/position_encoding.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+
+ from hotr.util.misc import NestedTensor
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         mask = tensor_list.mask
+         assert mask is not None
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return pos
+
+
+ def build_position_encoding(args):
+     N_steps = args.hidden_dim // 2
+     if args.position_embedding in ('v2', 'sine'):
+         # TODO find a better way of exposing other arguments
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif args.position_embedding in ('v3', 'learned'):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {args.position_embedding}")
+
+     return position_embedding
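A quick shape check for the sine embedding above (a hedged sketch; the batch and feature sizes are made up, and `NestedTensor(tensors, mask)` is the wrapper imported from `hotr/util/misc.py`):

```python
import torch
from hotr.util.misc import NestedTensor  # assumed available, as imported above

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
imgs = torch.randn(2, 3, 32, 40)
mask = torch.zeros(2, 32, 40, dtype=torch.bool)  # no padded pixels

pos = pe(NestedTensor(imgs, mask))
# pos: (2, 256, 32, 40) -- y and x sine features concatenated along the channel dim
```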
hotr/models/post_process.py ADDED
@@ -0,0 +1,162 @@
+ # ------------------------------------------------------------------------
+ # HOTR official code : hotr/models/post_process.py
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+ import time
+ import copy
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from hotr.util import box_ops
+
+ class PostProcess(nn.Module):
+     """ This module converts the model's output into the format expected by the coco api"""
+     def __init__(self, HOIDet):
+         super().__init__()
+         self.HOIDet = HOIDet
+
+     @torch.no_grad()
+     def forward(self, outputs, target_sizes, threshold=0, dataset='coco', args=None):
+         """ Perform the computation
+         Parameters:
+             outputs: raw outputs of the model
+             target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+                           For evaluation, this must be the original image size (before any data augmentation)
+                           For visualization, this should be the image size after data augmentation, but before padding
+         """
+         out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+         num_path = 1 + len(args.augpath_name)
+         path_id = args.path_id
+         assert len(out_logits) == len(target_sizes)
+         assert target_sizes.shape[1] == 2
+
+         prob = F.softmax(out_logits, -1)
+         scores, labels = prob[..., :-1].max(-1)
+
+         boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+         img_h, img_w = target_sizes.unbind(1)
+         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+         boxes = boxes * scale_fct[:, None, :]
+
+         # Prediction Branch for HOI detection
+         if self.HOIDet:
+             if dataset == 'vcoco':
+                 """ Compute HOI triplet prediction score for V-COCO.
+                 Our scoring function follows the implementation details of UnionDet.
+                 """
+                 out_time = outputs['hoi_recognition_time']
+                 bss, q, hd = outputs['pred_hidx'].shape
+                 start_time = time.time()
+                 pair_actions = torch.sigmoid(outputs['pred_actions'][:, path_id, ...])
+                 h_prob = F.softmax(outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
+                 h_idx_score, h_indices = h_prob.max(-1)
+
+                 o_prob = F.softmax(outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
+                 o_idx_score, o_indices = o_prob.max(-1)
+                 hoi_recognition_time = (time.time() - start_time) + out_time
+
+                 results = []
+                 # iterate over the batch
+                 for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
+                     h_inds = (l == 1) & (s > threshold)
+                     o_inds = (s > threshold)
+
+                     h_box, h_cat = b[h_inds], s[h_inds]
+                     o_box, o_cat = b[o_inds], s[o_inds]
+
+                     # for scenario 1 in the v-coco dataset
+                     o_inds = torch.cat((o_inds, torch.ones(1).type(torch.bool).to(o_inds.device)))
+                     o_box = torch.cat((o_box, torch.Tensor([0, 0, 0, 0]).unsqueeze(0).to(o_box.device)))
+
+                     result_dict = {
+                         'h_box': h_box, 'h_cat': h_cat,
+                         'o_box': o_box, 'o_cat': o_cat,
+                         'scores': s, 'labels': l, 'boxes': b
+                     }
+
+                     h_inds_lst = h_inds.nonzero(as_tuple=False).squeeze(-1)
+                     o_inds_lst = o_inds.nonzero(as_tuple=False).squeeze(-1)
+
+                     K = boxes.shape[1]
+                     n_act = pair_actions[batch_idx][:, :-1].shape[-1]
+                     score = torch.zeros((n_act, K, K + 1)).to(pair_actions[batch_idx].device)
+                     sorted_score = torch.zeros((n_act, K, K + 1)).to(pair_actions[batch_idx].device)
+                     id_score = torch.zeros((K, K + 1)).to(pair_actions[batch_idx].device)
+
+                     # Score function
+                     for hs, h_idx, os, o_idx, pair_action in zip(h_idx_score[batch_idx], h_indices[batch_idx], o_idx_score[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]):
+                         matching_score = (1 - pair_action[-1])  # no-interaction score
+                         if h_idx == o_idx: o_idx = -1
+                         if matching_score > id_score[h_idx, o_idx]:
+                             id_score[h_idx, o_idx] = matching_score
+                             sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1]
+                         score[:, h_idx, o_idx] += matching_score * pair_action[:-1]
+
+                     score += sorted_score
+                     score = score[:, h_inds, :]
+                     score = score[:, :, o_inds]
+
+                     result_dict.update({
+                         'pair_score': score,
+                         'hoi_recognition_time': hoi_recognition_time,
+                     })
+
+                     results.append(result_dict)
+
+             elif dataset == 'hico-det':
+                 """ Compute HOI triplet prediction score for HICO-DET.
+                 For HICO-DET, we follow the same scoring function but do not accumulate the results.
+                 """
+                 bss, q, hd = outputs['pred_hidx'].shape
+                 out_time = outputs['hoi_recognition_time']
+                 a, b, c = outputs['pred_obj_logits'].shape
+                 start_time = time.time()
+                 out_obj_logits = outputs['pred_obj_logits'].view(-1, num_path, b, c)[:, path_id, ...]
+                 out_verb_logits = outputs['pred_actions'][:, path_id, ...]
+
+                 # actions
+                 matching_scores = (1 - out_verb_logits.sigmoid()[..., -1:])  # * (1-out_verb_logits.sigmoid()[..., 57:58])
+                 verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores
+
+                 # hbox, obox
+                 outputs_hrepr, outputs_orepr = outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id], outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id]
+                 obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)
+
+                 h_prob = F.softmax(outputs_hrepr, -1)
+                 h_idx_score, h_indices = h_prob.max(-1)
+
+                 # targets
+                 o_prob = F.softmax(outputs_orepr, -1)
+                 o_idx_score, o_indices = o_prob.max(-1)
+                 hoi_recognition_time = (time.time() - start_time) + out_time
+
+                 # hidx, oidx
+                 sub_boxes, obj_boxes = [], []
+                 for batch_id, (box, h_idx, o_idx) in enumerate(zip(boxes, h_indices, o_indices)):
+                     sub_boxes.append(box[h_idx, :])
+                     obj_boxes.append(box[o_idx, :])
+                 sub_boxes = torch.stack(sub_boxes, dim=0)
+                 obj_boxes = torch.stack(obj_boxes, dim=0)
+
+                 # accumulate results (iterate through interaction queries)
+                 results = []
+                 for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes):
+                     sl = torch.full_like(ol, 0)  # self.subject_category_id = 0 in HICO-DET
+                     l = torch.cat((sl, ol))
+                     b = torch.cat((sb, ob))
+                     results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')})
+                     vs = vs * os.unsqueeze(1)
+                     ids = torch.arange(b.shape[0])
+                     res_dict = {
+                         'verb_scores': vs.to('cpu'),
+                         'sub_ids': ids[:ids.shape[0] // 2],
+                         'obj_ids': ids[ids.shape[0] // 2:],
+                         'hoi_recognition_time': hoi_recognition_time
+                     }
+                     results[-1].update(res_dict)
+         else:
+             results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
+
+         return results
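Both branches above share one scoring idea: the sigmoid of the extra "no interaction" logit acts as an explicit suppressor, so each verb score is its own probability times `1 - p(no_interaction)` (times the matched object's score for HICO-DET). A toy distillation of that step, with made-up sizes:

```python
import torch

verb_logits = torch.randn(4, 6)   # 4 interaction queries: 5 verbs + no-interaction
obj_scores = torch.rand(4)        # matched-object confidence per query

matching = 1 - verb_logits.sigmoid()[:, -1:]             # (4, 1) interactiveness
verb_scores = verb_logits.sigmoid()[:, :-1] * matching   # suppress non-interactions
verb_scores = verb_scores * obj_scores.unsqueeze(1)      # final triplet scores
```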
hotr/models/transformer.py ADDED
@@ -0,0 +1,320 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : hotr/models/transformer.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ """
9
+ DETR & HOTR Transformer class.
10
+ Copy-paste from torch.nn.Transformer with modifications:
11
+ * positional encodings are passed in MHattention
12
+ * extra LN at the end of encoder is removed
13
+ * decoder returns a stack of activations from all decoding layers
14
+ """
15
+ import copy
16
+ from typing import Optional, List
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn, Tensor
21
+
22
+
23
+ class Transformer(nn.Module):
24
+
25
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
26
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
27
+ activation="relu", normalize_before=False,
28
+ return_intermediate_dec=False):
29
+ super().__init__()
30
+
31
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
32
+ dropout, activation, normalize_before)
33
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
34
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
35
+
36
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
37
+ dropout, activation, normalize_before)
38
+ decoder_norm = nn.LayerNorm(d_model)
39
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
40
+ return_intermediate=return_intermediate_dec)
41
+
42
+ self._reset_parameters()
43
+ self.d_model = d_model
44
+ self.nhead = nhead
45
+
46
+ def _reset_parameters(self):
47
+ for p in self.parameters():
48
+ if p.dim() > 1:
49
+ nn.init.xavier_uniform_(p)
50
+
51
+ def forward(self, src, mask, query_embed, pos_embed,query_obj=None, return_decoder_input=False):
52
+ # flatten NxCxHxW to HWxNxC
53
+ bs, c, h, w = src.shape
54
+ src = src.flatten(2).permute(2, 0, 1)
55
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
56
+
57
+ if query_embed.dim()==2:
58
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
59
+ mask = mask.flatten(1)
60
+
61
+ tgt = torch.zeros_like(query_embed)
62
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
63
+ if query_obj is None:
64
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
65
+ else:
66
+
67
+ hs = self.decoder(query_obj, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
68
+
69
+ return hs.transpose(1, 2), memory,
70
+
71
+ class TransformerEncoder(nn.Module):
72
+
73
+ def __init__(self, encoder_layer, num_layers, norm=None):
74
+ super().__init__()
75
+ self.layers = _get_clones(encoder_layer, num_layers)
76
+ self.num_layers = num_layers
77
+ self.norm = norm
78
+
79
+ def forward(self, src,
80
+ mask: Optional[Tensor] = None,
81
+ src_key_padding_mask: Optional[Tensor] = None,
82
+ pos: Optional[Tensor] = None):
83
+ output = src
84
+
85
+ for layer in self.layers:
86
+ output = layer(output, src_mask=mask,
87
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
88
+
89
+ if self.norm is not None:
90
+ output = self.norm(output)
91
+
92
+ return output
93
+
94
+
95
+ class TransformerDecoder(nn.Module):
96
+
97
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
98
+ super().__init__()
99
+ self.layers = _get_clones(decoder_layer, num_layers)
100
+ self.num_layers = num_layers
101
+ self.norm = norm
102
+ self.return_intermediate = return_intermediate
103
+
104
+ def forward(self, tgt, memory,
105
+ tgt_mask: Optional[Tensor] = None,
106
+ memory_mask: Optional[Tensor] = None,
107
+ tgt_key_padding_mask: Optional[Tensor] = None,
108
+ memory_key_padding_mask: Optional[Tensor] = None,
109
+ pos: Optional[Tensor] = None,
110
+ query_pos: Optional[Tensor] = None):
111
+ output = tgt
112
+
113
+ intermediate = []
114
+
115
+ for layer in self.layers:
116
+ output = layer(output, memory, tgt_mask=tgt_mask,
117
+ memory_mask=memory_mask,
118
+ tgt_key_padding_mask=tgt_key_padding_mask,
119
+ memory_key_padding_mask=memory_key_padding_mask,
120
+ pos=pos, query_pos=query_pos)
121
+ if self.return_intermediate:
122
+ intermediate.append(self.norm(output))
123
+
124
+ if self.norm is not None:
125
+ output = self.norm(output)
126
+ if self.return_intermediate:
127
+ intermediate.pop()
128
+ intermediate.append(output)
129
+
130
+ if self.return_intermediate:
131
+ return torch.stack(intermediate)
132
+
133
+ return output.unsqueeze(0)
134
+
135
+
136
+ class TransformerEncoderLayer(nn.Module):
137
+
138
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
139
+ activation="relu", normalize_before=False):
140
+ super().__init__()
141
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
142
+ # Implementation of Feedforward model
143
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
144
+ self.dropout = nn.Dropout(dropout)
145
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
146
+
147
+ self.norm1 = nn.LayerNorm(d_model)
148
+ self.norm2 = nn.LayerNorm(d_model)
149
+ self.dropout1 = nn.Dropout(dropout)
150
+ self.dropout2 = nn.Dropout(dropout)
151
+
152
+ self.activation = _get_activation_fn(activation)
153
+ self.normalize_before = normalize_before
154
+
155
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
156
+ return tensor if pos is None else tensor + pos
157
+
158
+ def forward_post(self,
159
+ src,
160
+ src_mask: Optional[Tensor] = None,
161
+ src_key_padding_mask: Optional[Tensor] = None,
162
+ pos: Optional[Tensor] = None):
163
+ q = k = self.with_pos_embed(src, pos)
164
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
165
+ key_padding_mask=src_key_padding_mask)[0]
166
+ src = src + self.dropout1(src2)
167
+ src = self.norm1(src)
168
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
169
+ src = src + self.dropout2(src2)
170
+ src = self.norm2(src)
171
+ return src
172
+
173
+ def forward_pre(self, src,
174
+ src_mask: Optional[Tensor] = None,
175
+ src_key_padding_mask: Optional[Tensor] = None,
176
+ pos: Optional[Tensor] = None):
177
+ src2 = self.norm1(src)
178
+ q = k = self.with_pos_embed(src2, pos)
179
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
180
+ key_padding_mask=src_key_padding_mask)[0]
181
+ src = src + self.dropout1(src2)
182
+ src2 = self.norm2(src)
183
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
184
+ src = src + self.dropout2(src2)
185
+ return src
186
+
187
+ def forward(self, src,
188
+ src_mask: Optional[Tensor] = None,
189
+ src_key_padding_mask: Optional[Tensor] = None,
190
+ pos: Optional[Tensor] = None):
191
+ if self.normalize_before:
192
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
193
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
194
+
195
+
196
+ class TransformerDecoderLayer(nn.Module):
197
+
198
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
199
+ activation="relu", normalize_before=False):
200
+ super().__init__()
201
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
202
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
203
+ # Implementation of Feedforward model
204
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
205
+ self.dropout = nn.Dropout(dropout)
206
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
207
+
208
+ self.norm1 = nn.LayerNorm(d_model)
209
+ self.norm2 = nn.LayerNorm(d_model)
210
+ self.norm3 = nn.LayerNorm(d_model)
211
+ self.dropout1 = nn.Dropout(dropout)
212
+ self.dropout2 = nn.Dropout(dropout)
213
+ self.dropout3 = nn.Dropout(dropout)
214
+
215
+ self.activation = _get_activation_fn(activation)
216
+ self.normalize_before = normalize_before
217
+
218
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
219
+ return tensor if pos is None else tensor + pos
220
+
221
+ def forward_post(self, tgt, memory,
222
+ tgt_mask: Optional[Tensor] = None,
223
+ memory_mask: Optional[Tensor] = None,
224
+ tgt_key_padding_mask: Optional[Tensor] = None,
225
+ memory_key_padding_mask: Optional[Tensor] = None,
226
+ pos: Optional[Tensor] = None,
227
+ query_pos: Optional[Tensor] = None):
228
+
229
+ q = k = self.with_pos_embed(tgt, query_pos)
230
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
231
+ key_padding_mask=tgt_key_padding_mask)[0]
232
+ tgt = tgt + self.dropout1(tgt2)
233
+ tgt = self.norm1(tgt)
234
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
235
+ key=self.with_pos_embed(memory, pos),
236
+ value=memory, attn_mask=memory_mask,
237
+ key_padding_mask=memory_key_padding_mask)[0]
238
+ tgt = tgt + self.dropout2(tgt2)
239
+ tgt = self.norm2(tgt)
240
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
241
+ tgt = tgt + self.dropout3(tgt2)
242
+ tgt = self.norm3(tgt)
243
+ return tgt
244
+
245
+ def forward_pre(self, tgt, memory,
246
+ tgt_mask: Optional[Tensor] = None,
247
+ memory_mask: Optional[Tensor] = None,
248
+ tgt_key_padding_mask: Optional[Tensor] = None,
249
+ memory_key_padding_mask: Optional[Tensor] = None,
250
+ pos: Optional[Tensor] = None,
251
+ query_pos: Optional[Tensor] = None):
252
+ tgt2 = self.norm1(tgt)
253
+ q = k = self.with_pos_embed(tgt2, query_pos)
254
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
255
+ key_padding_mask=tgt_key_padding_mask)[0]
256
+ tgt = tgt + self.dropout1(tgt2)
257
+ tgt2 = self.norm2(tgt)
258
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
259
+ key=self.with_pos_embed(memory, pos),
260
+ value=memory, attn_mask=memory_mask,
261
+ key_padding_mask=memory_key_padding_mask)[0]
262
+ tgt = tgt + self.dropout2(tgt2)
263
+ tgt2 = self.norm3(tgt)
264
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
265
+ tgt = tgt + self.dropout3(tgt2)
266
+ return tgt
267
+
268
+ def forward(self, tgt, memory,
269
+ tgt_mask: Optional[Tensor] = None,
270
+ memory_mask: Optional[Tensor] = None,
271
+ tgt_key_padding_mask: Optional[Tensor] = None,
272
+ memory_key_padding_mask: Optional[Tensor] = None,
273
+ pos: Optional[Tensor] = None,
274
+ query_pos: Optional[Tensor] = None):
275
+ if self.normalize_before:
276
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
277
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
278
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
279
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
280
+
281
+
282
+ def _get_clones(module, N):
283
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
284
+
285
+
286
+ def build_transformer(args):
287
+ return Transformer(
288
+ d_model=args.hidden_dim,
289
+ dropout=args.dropout,
290
+ nhead=args.nheads,
291
+ dim_feedforward=args.dim_feedforward,
292
+ num_encoder_layers=args.enc_layers,
293
+ num_decoder_layers=args.dec_layers,
294
+ normalize_before=args.pre_norm,
295
+ return_intermediate_dec=True,
296
+ )
297
+
298
+
299
+ def build_hoi_transformer(args):
300
+ return Transformer(
301
+ d_model=args.hidden_dim,
302
+ dropout=args.dropout,
303
+ nhead=args.hoi_nheads,
304
+ dim_feedforward=args.hoi_dim_feedforward,
305
+ num_encoder_layers=args.hoi_enc_layers,
306
+ num_decoder_layers=args.hoi_dec_layers,
307
+ normalize_before=args.pre_norm,
308
+ return_intermediate_dec=True,
309
+ )
310
+
311
+
312
+ def _get_activation_fn(activation):
313
+ """Return an activation function given a string"""
314
+ if activation == "relu":
315
+ return F.relu
316
+ if activation == "gelu":
317
+ return F.gelu
318
+ if activation == "glu":
319
+ return F.glu
320
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
hotr/util/__init__.py ADDED
File without changes
hotr/util/box_ops.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
+ def box_cxcywh_to_xyxy(x):
10
+ x_c, y_c, w, h = x.unbind(-1)
11
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
13
+ return torch.stack(b, dim=-1)
14
+
15
+
16
+ def box_xyxy_to_cxcywh(x):
17
+ x0, y0, x1, y1 = x.unbind(-1)
18
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
19
+ (x1 - x0), (y1 - y0)]
20
+ return torch.stack(b, dim=-1)
21
+
22
+
23
+ # modified from torchvision to also return the union
24
+ def box_iou(boxes1, boxes2):
25
+ area1 = box_area(boxes1)
26
+ area2 = box_area(boxes2)
27
+
28
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30
+
31
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
32
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33
+
34
+ union = area1[:, None] + area2 - inter
35
+
36
+ iou = inter / union
37
+ return iou, union
38
+
39
+
40
+ def generalized_box_iou(boxes1, boxes2):
41
+ """
42
+ Generalized IoU from https://giou.stanford.edu/
43
+ The boxes should be in [x0, y0, x1, y1] format
44
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
45
+ and M = len(boxes2)
46
+ """
47
+ # degenerate boxes give inf / nan results,
48
+ # so do an early check
49
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
50
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
51
+ iou, union = box_iou(boxes1, boxes2)
52
+
53
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
54
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
55
+
56
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
57
+ area = wh[:, :, 0] * wh[:, :, 1]
58
+
59
+ return iou - (area - union) / area
60
+
61
+
62
+ def masks_to_boxes(masks):
63
+ """Compute the bounding boxes around the provided masks
64
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
65
+ Returns a [N, 4] tensors, with the boxes in xyxy format
66
+ """
67
+ if masks.numel() == 0:
68
+ return torch.zeros((0, 4), device=masks.device)
69
+
70
+ h, w = masks.shape[-2:]
71
+
72
+ y = torch.arange(0, h, dtype=torch.float)
73
+ x = torch.arange(0, w, dtype=torch.float)
74
+ y, x = torch.meshgrid(y, x)
75
+
76
+ x_mask = (masks * x.unsqueeze(0))
77
+ x_max = x_mask.flatten(1).max(-1)[0]
78
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
79
+
80
+ y_mask = (masks * y.unsqueeze(0))
81
+ y_max = y_mask.flatten(1).max(-1)[0]
82
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83
+
84
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
85
+
86
+
87
+ def rescale_bboxes(out_bbox, size):
88
+ img_h, img_w = size
89
+ b = box_cxcywh_to_xyxy(out_bbox)
90
+ b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.device)  # .device also covers CPU tensors, unlike get_device()
91
+ return b
92
+
93
+
94
+ def rescale_pairs(out_pairs, size):
95
+ img_h, img_w = size
96
+ h_bbox = out_pairs[:, :4]
97
+ o_bbox = out_pairs[:, 4:]
98
+
99
+ h = box_cxcywh_to_xyxy(h_bbox)
100
+ h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.device)
101
+
102
+ obj_mask = (o_bbox[:, 0] != -1)
103
+ if obj_mask.sum() != 0:
104
+ o = box_cxcywh_to_xyxy(o_bbox)
105
+ o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.device)
106
+ o_bbox[obj_mask] = o[obj_mask]
107
+ o = o_bbox
108
+ p = torch.cat([h, o], dim=-1)
109
+
110
+ return p
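
Note: a quick sanity check for the box helpers above (a sketch with illustrative values; identical boxes should give an IoU and a GIoU of exactly 1).

    import torch
    from hotr.util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou

    boxes_c = torch.tensor([[0.5, 0.5, 0.2, 0.2]])  # (cx, cy, w, h)
    boxes_x = box_cxcywh_to_xyxy(boxes_c)           # tensor([[0.4, 0.4, 0.6, 0.6]])
    giou = generalized_box_iou(boxes_x, boxes_x)    # tensor([[1.]])
    roundtrip = box_xyxy_to_cxcywh(boxes_x)         # recovers boxes_c
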
hotr/util/logger.py ADDED
@@ -0,0 +1,145 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : hotr/util/logger.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ import torch
9
+ import time
10
+ import datetime
11
+ import sys
12
+ from time import sleep
13
+ from collections import defaultdict
14
+
15
+ from hotr.util.misc import SmoothedValue
16
+
17
+ def print_params(model):
18
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
19
+ print('\n[Logger] Number of params: ', n_parameters)
20
+ return n_parameters
21
+
22
+ def print_args(args):
23
+ print('\n[Logger] DETR Arguments:')
24
+ for k, v in vars(args).items():
25
+ if k in [
26
+ 'lr', 'lr_backbone', 'lr_drop',
27
+ 'frozen_weights',
28
+ 'backbone', 'dilation',
29
+ 'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
30
+ 'dataset_file']:
31
+ print(f'\t{k}: {v}')
32
+
33
+ if args.HOIDet:
34
+ print('\n[Logger] DETR_HOI Arguments:')
35
+ for k, v in vars(args).items():
36
+ if k in [
37
+ 'freeze_enc',
38
+ 'query_flag',
39
+ 'hoi_nheads',
40
+ 'hoi_dim_feedforward',
41
+ 'hoi_dec_layers',
42
+ 'hoi_idx_loss_coef',
43
+ 'hoi_act_loss_coef',
44
+ 'hoi_eos_coef',
45
+ 'object_threshold']:
46
+ print(f'\t{k}: {v}')
47
+
48
+ class MetricLogger(object):
49
+ def __init__(self, mode="test", delimiter="\t"):
50
+ self.meters = defaultdict(SmoothedValue)
51
+ self.delimiter = delimiter
52
+ self.mode = mode
53
+
54
+ def update(self, **kwargs):
55
+ for k, v in kwargs.items():
56
+ if isinstance(v, torch.Tensor):
57
+ v = v.item()
58
+ assert isinstance(v, (float, int))
59
+ self.meters[k].update(v)
60
+
61
+ def __getattr__(self, attr):
62
+ if attr in self.meters:
63
+ return self.meters[attr]
64
+ if attr in self.__dict__:
65
+ return self.__dict__[attr]
66
+ raise AttributeError("'{}' object has no attribute '{}'".format(
67
+ type(self).__name__, attr))
68
+
69
+ def __str__(self):
70
+ loss_str = []
71
+ for name, meter in self.meters.items():
72
+ loss_str.append(
73
+ "{}: {}".format(name, str(meter))
74
+ )
75
+ return self.delimiter.join(loss_str)
76
+
77
+ def synchronize_between_processes(self):
78
+ for meter in self.meters.values():
79
+ meter.synchronize_between_processes()
80
+
81
+ def add_meter(self, name, meter):
82
+ self.meters[name] = meter
83
+
84
+ def log_every(self, iterable, print_freq, header=None):
85
+ i = 0
86
+ if not header:
87
+ header = ''
88
+ start_time = time.time()
89
+ end = time.time()
90
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
91
+ data_time = SmoothedValue(fmt='{avg:.4f}')
92
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
93
+ if torch.cuda.is_available():
94
+ log_msg = self.delimiter.join([
95
+ header,
96
+ '[{0' + space_fmt + '}/{1}]',
97
+ 'eta: {eta}',
98
+ '{meters}',
99
+ 'time: {time}',
100
+ 'data: {data}',
101
+ 'max mem: {memory:.0f}'
102
+ ])
103
+ else:
104
+ log_msg = self.delimiter.join([
105
+ header,
106
+ '[{0' + space_fmt + '}/{1}]',
107
+ 'eta: {eta}',
108
+ '{meters}',
109
+ 'time: {time}',
110
+ 'data: {data}'
111
+ ])
112
+ MB = 1024.0 * 1024.0
113
+ for obj in iterable:
114
+ data_time.update(time.time() - end)
115
+ yield obj
116
+ iter_time.update(time.time() - end)
117
+
118
+ if (i % print_freq == 0 and i != 0) or i == len(iterable) - 1:
119
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
120
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
121
+ if torch.cuda.is_available():
122
+ print(log_msg.format(
123
+ i+1, len(iterable), eta=eta_string,
124
+ meters=str(self),
125
+ time=str(iter_time), data=str(data_time),
126
+ memory=torch.cuda.max_memory_allocated() / MB),
127
+ flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
128
+ else:
129
+ print(log_msg.format(
130
+ i+1, len(iterable), eta=eta_string,
131
+ meters=str(self),
132
+ time=str(iter_time), data=str(data_time)),
133
+ flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
134
+ else:
135
+ log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]'])
136
+ # brief progress line shown between full logging intervals
137
+ print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")
138
+
139
+ i += 1
140
+ end = time.time()
141
+ total_time = time.time() - start_time
142
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
143
+ if self.mode=='test': print("")
144
+ print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
145
+ self.mode, total_time_str, total_time / len(iterable)))
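
Note: a sketch of how the MetricLogger above is typically driven; the iterable and loss here are dummies, whereas in training the iterable would be a DataLoader and update() would receive real loss scalars.

    from hotr.util.logger import MetricLogger

    logger = MetricLogger(mode="train", delimiter="  ")
    for i in logger.log_every(range(100), print_freq=10, header='Epoch: [0]'):
        logger.update(loss=1.0 / (i + 1))  # placeholder scalar per iteration
    print("Averaged stats:", logger)
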
hotr/util/misc.py ADDED
@@ -0,0 +1,401 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : hotr/util/misc.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ """
9
+ Misc functions, including distributed helpers.
10
+ Mostly copy-paste from torchvision references.
11
+ """
12
+ import os
13
+ import subprocess
14
+ from collections import deque
15
+ import pickle
16
+ import socket
17
+ from typing import Optional, List
18
+ import argparse, ast, time  # argparse and time are used below by arg_as_list and _maybe_gethostbyname
19
+ import torch
20
+ import torch.distributed as dist
21
+ from torch import Tensor
22
+
23
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
24
+ import torchvision
25
+ if tuple(map(int, torchvision.__version__.split('.')[:2])) < (0, 7):  # robust to two-digit minors like 0.10
26
+ from torchvision.ops import _new_empty_tensor
27
+ from torchvision.ops.misc import _output_size
28
+
29
+ os.environ['MASTER_PORT'] = '8993'  # hard-coded master port used by init_distributed_mode below
30
+ class SmoothedValue(object):
31
+ """Track a series of values and provide access to smoothed values over a
32
+ window or the global series average.
33
+ """
34
+
35
+ def __init__(self, window_size=20, fmt=None):
36
+ if fmt is None:
37
+ fmt = "{median:.4f} ({global_avg:.4f})"
38
+ self.deque = deque(maxlen=window_size)
39
+ self.total = 0.0
40
+ self.count = 0
41
+ self.fmt = fmt
42
+
43
+ def update(self, value, n=1):
44
+ self.deque.append(value)
45
+ self.count += n
46
+ self.total += value * n
47
+
48
+ def synchronize_between_processes(self):
49
+ """
50
+ Warning: does not synchronize the deque!
51
+ """
52
+ if not is_dist_avail_and_initialized():
53
+ return
54
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55
+ dist.barrier()
56
+ dist.all_reduce(t)
57
+ t = t.tolist()
58
+ self.count = int(t[0])
59
+ self.total = t[1]
60
+
61
+ @property
62
+ def median(self):
63
+ d = torch.tensor(list(self.deque))
64
+ return d.median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
69
+ return d.mean().item()
70
+
71
+ @property
72
+ def global_avg(self):
73
+ return self.total / self.count
74
+
75
+ @property
76
+ def max(self):
77
+ return max(self.deque)
78
+
79
+ @property
80
+ def value(self):
81
+ return self.deque[-1]
82
+
83
+ def __str__(self):
84
+ return self.fmt.format(
85
+ median=self.median,
86
+ avg=self.avg,
87
+ global_avg=self.global_avg,
88
+ max=self.max,
89
+ value=self.value)
90
+
91
+
92
+ def all_gather(data):
93
+ """
94
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
95
+ Args:
96
+ data: any picklable object
97
+ Returns:
98
+ list[data]: list of data gathered from each rank
99
+ """
100
+ world_size = get_world_size()
101
+ if world_size == 1:
102
+ return [data]
103
+
104
+ # serialized to a Tensor
105
+ buffer = pickle.dumps(data)
106
+ storage = torch.ByteStorage.from_buffer(buffer)
107
+ tensor = torch.ByteTensor(storage).to("cuda")
108
+
109
+ # obtain Tensor size of each rank
110
+ local_size = torch.tensor([tensor.numel()], device="cuda")
111
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
112
+ dist.all_gather(size_list, local_size)
113
+ size_list = [int(size.item()) for size in size_list]
114
+ max_size = max(size_list)
115
+
116
+ # receiving Tensor from all ranks
117
+ # we pad the tensor because torch all_gather does not support
118
+ # gathering tensors of different shapes
119
+ tensor_list = []
120
+ for _ in size_list:
121
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
122
+ if local_size != max_size:
123
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
124
+ tensor = torch.cat((tensor, padding), dim=0)
125
+ dist.all_gather(tensor_list, tensor)
126
+
127
+ data_list = []
128
+ for size, tensor in zip(size_list, tensor_list):
129
+ buffer = tensor.cpu().numpy().tobytes()[:size]
130
+ data_list.append(pickle.loads(buffer))
131
+
132
+ return data_list
133
+
134
+
135
+ def reduce_dict(input_dict, average=True):
136
+ """
137
+ Args:
138
+ input_dict (dict): all the values will be reduced
139
+ average (bool): whether to do average or sum
140
+ Reduce the values in the dictionary from all processes so that all processes
141
+ have the averaged results. Returns a dict with the same fields as
142
+ input_dict, after reduction.
143
+ """
144
+ world_size = get_world_size()
145
+ if world_size < 2:
146
+ return input_dict
147
+ with torch.no_grad():
148
+ names = []
149
+ values = []
150
+ # sort the keys so that they are consistent across processes
151
+ for k in sorted(input_dict.keys()):
152
+ names.append(k)
153
+ values.append(input_dict[k])
154
+ values = torch.stack(values, dim=0)
155
+ dist.all_reduce(values)
156
+ if average:
157
+ values /= world_size
158
+ reduced_dict = {k: v for k, v in zip(names, values)}
159
+ return reduced_dict
160
+
161
+
162
+ def get_sha():
163
+ cwd = os.path.dirname(os.path.abspath(__file__))
164
+
165
+ def _run(command):
166
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
167
+ sha = 'N/A'
168
+ diff = "clean"
169
+ branch = 'N/A'
170
+ try:
171
+ sha = _run(['git', 'rev-parse', 'HEAD'])
172
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
173
+ diff = _run(['git', 'diff-index', 'HEAD'])
174
+ diff = "has uncommited changes" if diff else "clean"
175
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
176
+ except Exception:
177
+ pass
178
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
179
+ return message
180
+
181
+
182
+ def collate_fn(batch):
183
+ batch = list(zip(*batch))
184
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
185
+ return tuple(batch)
186
+
187
+
188
+ def _max_by_axis(the_list):
189
+ # type: (List[List[int]]) -> List[int]
190
+ maxes = the_list[0]
191
+ for sublist in the_list[1:]:
192
+ for index, item in enumerate(sublist):
193
+ maxes[index] = max(maxes[index], item)
194
+ return maxes
195
+
196
+
197
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
198
+ # TODO make this more general
199
+ if tensor_list[0].ndim == 3:
200
+ # TODO make it support different-sized images
201
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
202
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
203
+ batch_shape = [len(tensor_list)] + max_size
204
+ b, c, h, w = batch_shape
205
+ dtype = tensor_list[0].dtype
206
+ device = tensor_list[0].device
207
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
208
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
209
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
210
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
211
+ m[: img.shape[1], :img.shape[2]] = False
212
+ else:
213
+ raise ValueError('not supported')
214
+ return NestedTensor(tensor, mask)
215
+
216
+
217
+ class NestedTensor(object):
218
+ def __init__(self, tensors, mask: Optional[Tensor]):
219
+ self.tensors = tensors
220
+ self.mask = mask
221
+
222
+ def to(self, device):
223
+ # type: (Device) -> NestedTensor # noqa
224
+ cast_tensor = self.tensors.to(device)
225
+ mask = self.mask
226
+ if mask is not None:
227
+ assert mask is not None
228
+ cast_mask = mask.to(device)
229
+ else:
230
+ cast_mask = None
231
+ return NestedTensor(cast_tensor, cast_mask)
232
+
233
+ def decompose(self):
234
+ return self.tensors, self.mask
235
+
236
+ def __repr__(self):
237
+ return str(self.tensors)
238
+
239
+
240
+ def setup_for_distributed(is_master):
241
+ """
242
+ This function disables printing when not in master process
243
+ """
244
+ import builtins as __builtin__
245
+ builtin_print = __builtin__.print
246
+
247
+ def print(*args, **kwargs):
248
+ force = kwargs.pop('force', False)
249
+ if is_master or force:
250
+ builtin_print(*args, **kwargs)
251
+
252
+ __builtin__.print = print
253
+
254
+
255
+ def is_dist_avail_and_initialized():
256
+ if not dist.is_available():
257
+ return False
258
+ if not dist.is_initialized():
259
+ return False
260
+ return True
261
+
262
+
263
+ def get_world_size():
264
+ if not is_dist_avail_and_initialized():
265
+ return 1
266
+ return dist.get_world_size()
267
+
268
+
269
+ def get_rank():
270
+ if not is_dist_avail_and_initialized():
271
+ return 0
272
+ return dist.get_rank()
273
+
274
+
275
+ def is_main_process():
276
+ return get_rank() == 0
277
+
278
+
279
+ def save_on_master(*args, **kwargs):
280
+ if is_main_process():
281
+ torch.save(*args, **kwargs)
282
+
283
+
284
+ def _check_if_valid_ip(ip):
285
+ try:
286
+ socket.inet_aton(ip)
287
+ # legal
288
+ except socket.error:
289
+ # Not legal
290
+ return False
291
+ return True
292
+
293
+ def arg_as_list(s):
294
+ v = ast.literal_eval(s)
295
+ if type(v) is not list:
296
+ raise argparse.ArgumentTypeError("List should be given.")
297
+ return v
298
+
299
+ def _maybe_gethostbyname(addr):
300
+ """to be compatible with Braincloud on which one can access the nodes by their task names.
301
+ Each node has to wait until all the tasks in the group are up on the cloud."""
302
+ if _check_if_valid_ip(addr):
303
+ # If IP address is given, do nothing
304
+ return addr
305
+
306
+ # Otherwise, find the IP address by hostname
307
+ done = False
308
+ retry = 0
309
+ print(f"Get URL by the given hostname '{addr}' in Braincloud..")
310
+ while not done:
311
+ try:
312
+ addr = socket.gethostbyname(addr)
313
+ done = True
314
+ except socket.gaierror:  # hostname not resolvable yet; keep retrying
315
+ retry += 1
316
+ print(f"Retrying count: {retry}")
317
+ time.sleep(3)
318
+ print(f"Found the host by IP address: {addr}")
319
+ return addr
320
+
321
+
322
+ def init_distributed_mode(args):
323
+
324
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
325
+ os.environ["MASTER_ADDR"] = _maybe_gethostbyname(os.environ["MASTER_ADDR"])
326
+ args.rank = int(os.environ["RANK"])
327
+ args.world_size = int(os.environ['WORLD_SIZE'])
328
+ args.gpu = int(os.environ['LOCAL_RANK'])
329
+ args.dist_url = 'env://'
330
+ os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
331
+ elif 'SLURM_PROCID' in os.environ:
332
+ proc_id = int(os.environ['SLURM_PROCID'])
333
+ ntasks = int(os.environ['SLURM_NTASKS'])
334
+ node_list = os.environ['SLURM_NODELIST']
335
+ num_gpus = torch.cuda.device_count()
336
+ addr = subprocess.getoutput(
337
+ 'scontrol show hostname {} | head -n1'.format(node_list))
338
+ os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
339
+ os.environ['MASTER_ADDR'] = addr
340
+ os.environ['WORLD_SIZE'] = str(ntasks)
341
+ os.environ['RANK'] = str(proc_id)
342
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
343
+ os.environ['LOCAL_SIZE'] = str(num_gpus)
344
+ args.dist_url = 'env://'
345
+ args.world_size = ntasks
346
+ args.rank = proc_id
347
+ args.gpu = proc_id % num_gpus
348
+ else:
349
+ print('Not using distributed mode')
350
+ args.distributed = False
351
+ return
352
+
353
+ args.distributed = True
354
+
355
+ torch.cuda.set_device(args.gpu)
356
+ args.dist_backend = 'nccl'
357
+ print('| distributed init (rank {}): {}'.format(
358
+ args.rank, args.dist_url), flush=True)
359
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
360
+ world_size=args.world_size, rank=args.rank)
361
+ torch.distributed.barrier()
362
+ setup_for_distributed(args.rank == 0)
363
+
364
+
365
+ @torch.no_grad()
366
+ def accuracy(output, target, topk=(1,)):
367
+ """Computes the precision@k for the specified values of k"""
368
+ if target.numel() == 0:
369
+ return [torch.zeros([], device=output.device)]
370
+ maxk = max(topk)
371
+ batch_size = target.size(0)
372
+
373
+ _, pred = output.topk(maxk, 1, True, True)
374
+ pred = pred.t()
375
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
376
+
377
+ res = []
378
+ for k in topk:
379
+ correct_k = correct[:k].view(-1).float().sum(0)
380
+ res.append(correct_k.mul_(100.0 / batch_size))
381
+ return res
382
+
383
+
384
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
385
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
386
+ """
387
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
388
+ This will eventually be supported natively by PyTorch, and this
389
+ class can go away.
390
+ """
391
+ if tuple(map(int, torchvision.__version__.split('.')[:2])) < (0, 7):
392
+ if input.numel() > 0:
393
+ return torch.nn.functional.interpolate(
394
+ input, size, scale_factor, mode, align_corners
395
+ )
396
+
397
+ output_shape = _output_size(2, input, size, scale_factor)
398
+ output_shape = list(input.shape[:-2]) + list(output_shape)
399
+ return _new_empty_tensor(input, output_shape)
400
+ else:
401
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
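
Note: a short sketch of the NestedTensor batching path above, which collate_fn applies to every image batch (the image sizes are hypothetical).

    import torch
    from hotr.util.misc import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    tensors, mask = nested_tensor_from_tensor_list(imgs).decompose()
    print(tensors.shape)  # torch.Size([2, 3, 512, 640]), zero-padded to the max H and W
    print(mask.shape)     # torch.Size([2, 512, 640]), True marks padded pixels
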
hotr/util/ramp.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ import numpy as np
9
+
10
+ def sigmoid_rampup(current, rampup_length, max_coef=1.):
11
+ """Exponential rampup from https://arxiv.org/abs/1610.02242"""
12
+ # Modified from https://github.com/vikasverma1077/GraphMix/blob/master/semisupervised/codes/ramps.py
13
+ if rampup_length == 0:
14
+ return max_coef
15
+ else:
16
+ current = np.clip(current, 0.0, rampup_length)
17
+ phase = 1.0 - current / rampup_length
18
+ return float(np.exp(-5.0 * phase * phase)) * max_coef
19
+
20
+ def cosine_rampdown(current, rampdown_length, max_coef=1.):
21
+ """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
22
+ assert 0 <= current <= rampdown_length
23
+ return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) * max_coef
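
Note: the two schedules above evaluated at a few points (a quick numeric check): sigmoid_rampup rises from about exp(-5)*max_coef to max_coef over rampup_length, while cosine_rampdown falls from max_coef to 0.

    from hotr.util.ramp import sigmoid_rampup, cosine_rampdown

    for epoch in (0, 10, 20, 30):
        up = sigmoid_rampup(epoch, rampup_length=30)
        down = cosine_rampdown(epoch, rampdown_length=30)
        print(f"epoch {epoch:2d}: rampup {up:.3f}  rampdown {down:.3f}")
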
imgs/mainfig.png ADDED
main.py ADDED
@@ -0,0 +1,240 @@
1
+ # ------------------------------------------------------------------------
2
+ # HOTR official code : main.py
3
+ # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4
+ # ------------------------------------------------------------------------
5
+ # Modified from DETR (https://github.com/facebookresearch/detr)
6
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7
+ # ------------------------------------------------------------------------
8
+ import argparse
9
+ import datetime
10
+ import json
11
+ import random
12
+ import time
13
+ import multiprocessing
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch.utils.data import DataLoader, DistributedSampler
19
+
20
+ import hotr.data.datasets as datasets
21
+ import hotr.util.misc as utils
22
+ from hotr.engine.arg_parser import get_args_parser
23
+ from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
24
+ from hotr.engine.trainer import train_one_epoch
25
+ from hotr.engine import hoi_evaluator, hoi_accumulator
26
+ from hotr.models import build_model
27
+ import wandb
28
+
29
+ from hotr.util.logger import print_params, print_args
30
+
31
+ def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename):
32
+ # save_ckpt: function for saving checkpoints
33
+ output_dir = Path(args.output_dir)
34
+ if args.output_dir:
35
+ checkpoint_path = output_dir / f'{filename}.pth'
36
+ utils.save_on_master({
37
+ 'model': model_without_ddp.state_dict(),
38
+ 'optimizer': optimizer.state_dict(),
39
+ 'lr_scheduler': lr_scheduler.state_dict(),
40
+ 'epoch': epoch,
41
+ 'args': args,
42
+ }, checkpoint_path)
43
+
44
+ def main(args):
45
+ utils.init_distributed_mode(args)
46
+
47
+ if args.frozen_weights is not None:
48
+ print("Freeze weights for detector")
49
+
50
+ device = torch.device(args.device)
51
+
52
+ # fix the seed for reproducibility
53
+ seed = args.seed + utils.get_rank()
54
+ torch.manual_seed(seed)
55
+ np.random.seed(seed)
56
+ random.seed(seed)
57
+
58
+ # Data Setup
59
+ dataset_train = build_dataset(image_set='train', args=args)
60
+ dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args)
61
+ assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits"
62
+ args.num_classes = dataset_train.num_category()
63
+ args.num_actions = dataset_train.num_action()
64
+ args.action_names = dataset_train.get_actions()
65
+ if args.share_enc: args.hoi_enc_layers = args.enc_layers
66
+ if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
67
+ if args.dataset_file == 'vcoco':
68
+ # Save V-COCO dataset statistics
69
+ args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
70
+ args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
71
+ args.human_actions = dataset_train.get_human_action()
72
+ args.object_actions = dataset_train.get_object_action()
73
+ args.num_human_act = dataset_train.num_human_act()
74
+ elif args.dataset_file == 'hico-det':
75
+ args.valid_obj_ids = dataset_train.get_valid_obj_ids()
76
+ print_args(args)
77
+
78
+ if args.distributed:
79
+ sampler_train = DistributedSampler(dataset_train, shuffle=True)
80
+ sampler_val = DistributedSampler(dataset_val, shuffle=False)
81
+ else:
82
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
83
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
84
+
85
+ batch_sampler_train = torch.utils.data.BatchSampler(
86
+ sampler_train, args.batch_size, drop_last=True)
87
+
88
+ data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
89
+ collate_fn=utils.collate_fn, num_workers=args.num_workers)
90
+ data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
91
+ drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
92
+
93
+ # Model Setup
94
+ model, criterion, postprocessors = build_model(args)
95
+
96
+ model.to(device)
97
+
98
+ model_without_ddp = model
99
+ if args.distributed:
100
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
101
+ model_without_ddp = model.module
102
+ n_parameters = print_params(model)
103
+
104
+ param_dicts = [
105
+ {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
106
+ {
107
+ "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
108
+ "lr": args.lr_backbone,
109
+ },
110
+ ]
111
+ optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
112
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
113
+ # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [1,100])
114
+
115
+
116
+ # Weight Setup
117
+ if args.frozen_weights is not None:
118
+ if args.frozen_weights.startswith('https'):
119
+ checkpoint = torch.hub.load_state_dict_from_url(
120
+ args.frozen_weights, map_location='cpu', check_hash=True)
121
+ else:
122
+ checkpoint = torch.load(args.frozen_weights, map_location='cpu')
123
+ model_without_ddp.detr.load_state_dict(checkpoint['model'])
124
+
125
+ if args.resume:
126
+ if args.resume.startswith('https'):
127
+ checkpoint = torch.hub.load_state_dict_from_url(
128
+ args.resume, map_location='cpu', check_hash=True)
129
+ else:
130
+ checkpoint = torch.load(args.resume, map_location='cpu')
131
+ model_without_ddp.load_state_dict(checkpoint['model'])
132
+ if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
133
+ optimizer.load_state_dict(checkpoint['optimizer'])
134
+ # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
135
+ args.start_epoch = checkpoint['epoch'] + 1
136
+
137
+ if args.eval:
138
+ # test only mode
139
+ if args.HOIDet:
140
+ if args.dataset_file == 'vcoco':
141
+ total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
142
+ sc1, sc2 = hoi_accumulator(args, total_res, True, False)
143
+ elif args.dataset_file == 'hico-det':
144
+ test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
145
+ print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}')
146
+ print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}')
147
+ print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}')
148
+ else: raise ValueError(f'dataset {args.dataset_file} is not supported.')
149
+ return
150
+ else:
151
+ test_stats, coco_evaluator = evaluate_coco(model, criterion, postprocessors,
152
+ data_loader_val, get_coco_api_from_dataset(dataset_val), device, args.output_dir)  # COCO GT api for the val split
153
+ if args.output_dir:
154
+ utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, Path(args.output_dir) / "eval.pth")
155
+ return
156
+
157
+ # stats
158
+ scenario1, scenario2 = 0, 0
159
+ best_mAP, best_rare, best_non_rare = 0, 0, 0
160
+
161
+ # add argparse
162
+ if args.wandb and utils.get_rank() == 0:
163
+ wandb.init(
164
+ project=args.project_name,
165
+ group=args.group_name,
166
+ name=args.run_name,
167
+ config=args
168
+ )
169
+ wandb.watch(model)
170
+
171
+ # Training starts here!
172
+ # lr_scheduler.step()
173
+ start_time = time.time()
174
+ for epoch in range(args.start_epoch, args.epochs):
175
+ if args.distributed:
176
+ sampler_train.set_epoch(epoch)
177
+ train_stats = train_one_epoch(
178
+ model, criterion, data_loader_train, optimizer, device, epoch, args.epochs, args.ramp_up_epoch, args.ramp_down_epoch, args.hoi_consistency_loss_coef,
179
+ args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb)
180
+ lr_scheduler.step()
181
+
182
+ # Validation
183
+ if args.validate:
184
+ print('-'*100)
185
+ if args.dataset_file == 'vcoco':
186
+ total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
187
+ if utils.get_rank() == 0:
188
+ sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb)
189
+ if sc1 > scenario1:
190
+ scenario1 = sc1
191
+ scenario2 = sc2
192
+ save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
193
+ print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})')
194
+ print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})')
195
+ elif args.dataset_file == 'hico-det':
196
+ test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
197
+ if utils.get_rank() == 0:
198
+ if test_stats['mAP'] > best_mAP:
199
+ best_mAP = test_stats['mAP']
200
+ best_rare = test_stats['mAP rare']
201
+ best_non_rare = test_stats['mAP non-rare']
202
+ save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
203
+ print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})')
204
+ print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})')
205
+ print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})')
206
+ if args.wandb and utils.get_rank() == 0:
207
+ wandb.log({
208
+ 'mAP': test_stats['mAP'],
209
+ 'mAP rare': test_stats['mAP rare'],
210
+ 'mAP non-rare': test_stats['mAP non-rare']
211
+ })
212
+ print('-'*100)
213
+
214
+ save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint')
215
+ if (epoch + 1) % args.lr_drop == 0:
216
+ save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_'+str(epoch))
217
+ # if (epoch + 1) % args.pseudo_epoch == 0 :
218
+ # save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_pseudo_'+str(epoch))
219
+ total_time = time.time() - start_time
220
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
221
+ print('Training time {}'.format(total_time_str))
222
+ if args.dataset_file == 'vcoco':
223
+ print(f'| Scenario #1 mAP : {scenario1:.2f}')
224
+ print(f'| Scenario #2 mAP : {scenario2:.2f}')
225
+ elif args.dataset_file == 'hico-det':
226
+ print(f'| mAP (full)\t\t: {best_mAP:.2f}')
227
+ print(f'| mAP (rare)\t\t: {best_rare:.2f}')
228
+ print(f'| mAP (non-rare)\t: {best_non_rare:.2f}')
229
+
230
+
231
+ if __name__ == '__main__':
232
+ parser = argparse.ArgumentParser(
233
+ 'End-to-End Human Object Interaction training and evaluation script',
234
+ parents=[get_args_parser()]
235
+ )
236
+ args = parser.parse_args()
237
+ if args.output_dir:
238
+ args.output_dir += f"/{args.group_name}/{args.run_name}/"
239
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
240
+ main(args)
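
Note: checkpoints written by save_ckpt above can be inspected as below; the path is hypothetical and follows the args.output_dir/{group_name}/{run_name}/ layout set up in __main__.

    import torch

    ckpt = torch.load('checkpoints/my_group/my_run/best.pth', map_location='cpu')
    print(sorted(ckpt.keys()))         # ['args', 'epoch', 'lr_scheduler', 'model', 'optimizer']
    print('saved after epoch', ckpt['epoch'])
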
tools/launch.py ADDED
@@ -0,0 +1,192 @@
1
+ # --------------------------------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # --------------------------------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py
7
+ # --------------------------------------------------------------------------------------------------------------------------
8
+
9
+ r"""
10
+ `torch.distributed.launch` is a module that spawns multiple distributed
11
+ training processes on each of the training nodes.
12
+ The utility can be used for single-node distributed training, in which one or
13
+ more processes per node will be spawned. The utility can be used for either
14
+ CPU training or GPU training. If the utility is used for GPU training,
15
+ each distributed process will be operating on a single GPU. This can achieve
16
+ well-improved single-node training performance. It can also be used in
17
+ multi-node distributed training, by spawning up multiple processes on each node
18
+ for well-improved multi-node distributed training performance as well.
19
+ This will especially be beneficial for systems with multiple Infiniband
20
+ interfaces that have direct-GPU support, since all of them can be utilized for
21
+ aggregated communication bandwidth.
22
+ In both cases of single-node distributed training or multi-node distributed
23
+ training, this utility will launch the given number of processes per node
24
+ (``--nproc_per_node``). If used for GPU training, this number needs to be less
25
+ or equal to the number of GPUs on the current system (``nproc_per_node``),
26
+ and each process will be operating on a single GPU from *GPU 0 to
27
+ GPU (nproc_per_node - 1)*.
28
+ **How to use this module:**
29
+ 1. Single-Node multi-process distributed training
30
+ ::
31
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
32
+ YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
33
+ arguments of your training script)
34
+ 2. Multi-Node multi-process distributed training: (e.g. two nodes)
35
+ Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*
36
+ ::
37
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
38
+ --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
39
+ --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
40
+ and all other arguments of your training script)
41
+ Node 2:
42
+ ::
43
+ >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
44
+ --nnodes=2 --node_rank=1 --master_addr="192.168.1.1"
45
+ --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
46
+ and all other arguments of your training script)
47
+ 3. To look up what optional arguments this module offers:
48
+ ::
49
+ >>> python -m torch.distributed.launch --help
50
+ **Important Notices:**
51
+ 1. This utility and multi-process distributed (single-node or
52
+ multi-node) GPU training currently only achieves the best performance using
53
+ the NCCL distributed backend. Thus NCCL backend is the recommended backend to
54
+ use for GPU training.
55
+ 2. In your training program, you must parse the command-line argument:
56
+ ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
57
+ If your training program uses GPUs, you should ensure that your code only
58
+ runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
59
+ Parsing the local_rank argument
60
+ ::
61
+ >>> import argparse
62
+ >>> parser = argparse.ArgumentParser()
63
+ >>> parser.add_argument("--local_rank", type=int)
64
+ >>> args = parser.parse_args()
65
+ Set your device to local rank using either
66
+ ::
67
+ >>> torch.cuda.set_device(arg.local_rank) # before your code runs
68
+ or
69
+ ::
70
+ >>> with torch.cuda.device(arg.local_rank):
71
+ >>> # your code to run
72
+ 3. In your training program, you are supposed to call the following function
73
+ at the beginning to start the distributed backend. You need to make sure that
74
+ the init_method uses ``env://``, which is the only supported ``init_method``
75
+ by this module.
76
+ ::
77
+ torch.distributed.init_process_group(backend='YOUR BACKEND',
78
+ init_method='env://')
79
+ 4. In your training program, you can either use regular distributed functions
80
+ or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
81
+ training program uses GPUs for training and you would like to use
82
+ :func:`torch.nn.parallel.DistributedDataParallel` module,
83
+ here is how to configure it.
84
+ ::
85
+ model = torch.nn.parallel.DistributedDataParallel(model,
86
+ device_ids=[arg.local_rank],
87
+ output_device=arg.local_rank)
88
+ Please ensure that ``device_ids`` argument is set to be the only GPU device id
89
+ that your code will be operating on. This is generally the local rank of the
90
+ process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``,
91
+ and ``output_device`` needs to be ``args.local_rank`` in order to use this
92
+ utility.
93
+ 5. Another way is to pass ``local_rank`` to the subprocesses via the environment variable
94
+ ``LOCAL_RANK``. This behavior is enabled when you launch the script with
95
+ ``--use_env=True``. You must adjust the subprocess example above to replace
96
+ ``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
97
+ will not pass ``--local_rank`` when you specify this flag.
98
+ .. warning::
99
+ ``local_rank`` is NOT globally unique: it is only unique per process
100
+ on a machine. Thus, don't use it to decide if you should, e.g.,
101
+ write to a networked filesystem. See
102
+ https://github.com/pytorch/pytorch/issues/12042 for an example of
103
+ how things can go wrong if you don't do this correctly.
104
+ """
105
+
106
+
107
+ import sys
108
+ import subprocess
109
+ import os
110
+ import socket
111
+ from argparse import ArgumentParser, REMAINDER
112
+
113
+ import torch
114
+
115
+
116
+ def parse_args():
117
+ """
118
+ Helper function parsing the command line options
119
+ @retval ArgumentParser
120
+ """
121
+ parser = ArgumentParser(description="PyTorch distributed training launch "
122
+ "helper utilty that will spawn up "
123
+ "multiple distributed processes")
124
+
125
+ # Optional arguments for the launch helper
126
+ parser.add_argument("--nnodes", type=int, default=1,
127
+ help="The number of nodes to use for distributed "
128
+ "training")
129
+ parser.add_argument("--node_rank", type=int, default=0,
130
+ help="The rank of the node for multi-node distributed "
131
+ "training")
132
+ parser.add_argument("--nproc_per_node", type=int, default=1,
133
+ help="The number of processes to launch on each node, "
134
+ "for GPU training, this is recommended to be set "
135
+ "to the number of GPUs in your system so that "
136
+ "each process can be bound to a single GPU.")
137
+ parser.add_argument("--master_addr", default="127.0.0.1", type=str,
138
+ help="Master node (rank 0)'s address, should be either "
139
+ "the IP address or the hostname of node 0, for "
140
+ "single node multi-proc training, the "
141
+ "--master_addr can simply be 127.0.0.1")
142
+ parser.add_argument("--master_port", default=29500, type=int,
143
+ help="Master node (rank 0)'s free port that needs to "
144
+ "be used for communciation during distributed "
145
+ "training")
146
+
147
+ # positional
148
+ parser.add_argument("training_script", type=str,
149
+ help="The full path to the single GPU training "
150
+ "program/script to be launched in parallel, "
151
+ "followed by all the arguments for the "
152
+ "training script")
153
+
154
+ # rest from the training program
155
+ parser.add_argument('training_script_args', nargs=REMAINDER)
156
+ return parser.parse_args()
157
+
158
+
159
+ def main():
160
+ args = parse_args()
161
+
162
+ # world size in terms of number of processes
163
+ dist_world_size = args.nproc_per_node * args.nnodes
164
+
165
+ # set PyTorch distributed related environmental variables
166
+ current_env = os.environ.copy()
167
+ current_env["MASTER_ADDR"] = args.master_addr
168
+ current_env["MASTER_PORT"] = str(args.master_port)
169
+ current_env["WORLD_SIZE"] = str(dist_world_size)
170
+
171
+ processes = []
172
+
173
+ for local_rank in range(0, args.nproc_per_node):
174
+ # each process's rank
175
+ dist_rank = args.nproc_per_node * args.node_rank + local_rank
176
+ current_env["RANK"] = str(dist_rank)
177
+ current_env["LOCAL_RANK"] = str(local_rank)
178
+
179
+ cmd = [args.training_script] + args.training_script_args
180
+
181
+ process = subprocess.Popen(cmd, env=current_env)
182
+ processes.append(process)
183
+
184
+ for process in processes:
185
+ process.wait()
186
+ if process.returncode != 0:
187
+ raise subprocess.CalledProcessError(returncode=process.returncode,
188
+ cmd=process.args)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
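
Note: unlike the stock torch.distributed.launch it was adapted from, this launcher execs the given command directly (cmd = [args.training_script] + args.training_script_args), so the interpreter must be part of the command. A single-node run would look like: python ./tools/launch.py --nproc_per_node 8 python main.py <training args>. The wrapper script below does this for you.
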
tools/run_dist_launch.sh ADDED
@@ -0,0 +1,29 @@
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ set -x
9
+
10
+ GPUS=$1
11
+ RUN_COMMAND=${@:2}
12
+ if [ $GPUS -lt 8 ]; then
13
+ GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
14
+ else
15
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
16
+ fi
17
+ MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
18
+ MASTER_PORT=${MASTER_PORT:-"29500"}
19
+ NODE_RANK=${NODE_RANK:-0}
20
+
21
+ let "NNODES=GPUS/GPUS_PER_NODE"
22
+
23
+ python ./tools/launch.py \
24
+ --nnodes ${NNODES} \
25
+ --node_rank ${NODE_RANK} \
26
+ --master_addr ${MASTER_ADDR} \
27
+ --master_port ${MASTER_PORT} \
28
+ --nproc_per_node ${GPUS_PER_NODE} \
29
+ ${RUN_COMMAND}
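
Note: everything after the GPU count is forwarded verbatim to tools/launch.py, so a typical single-node run is: ./tools/run_dist_launch.sh 4 python main.py <training args>. For multi-node runs, export MASTER_ADDR, MASTER_PORT, NODE_RANK (and optionally GPUS_PER_NODE) on each node before invoking it.
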
tools/run_dist_slurm.sh ADDED
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env bash
2
+ # --------------------------------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # --------------------------------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh
8
+ # --------------------------------------------------------------------------------------------------------------------------
9
+
10
+ set -x
11
+
12
+ PARTITION=$1
13
+ JOB_NAME=$2
14
+ GPUS=$3
15
+ RUN_COMMAND=${@:4}
16
+ if [ $GPUS -lt 8 ]; then
17
+ GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
18
+ else
19
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
20
+ fi
21
+ CPUS_PER_TASK=${CPUS_PER_TASK:-4}
22
+ SRUN_ARGS=${SRUN_ARGS:-""}
23
+
24
+ srun -p ${PARTITION} \
25
+ --job-name=${JOB_NAME} \
26
+ --gres=gpu:${GPUS_PER_NODE} \
27
+ --ntasks=${GPUS} \
28
+ --ntasks-per-node=${GPUS_PER_NODE} \
29
+ --cpus-per-task=${CPUS_PER_TASK} \
30
+ --kill-on-bad-exit=1 \
31
+ ${SRUN_ARGS} \
32
+ ${RUN_COMMAND}
33
+
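
Note: the SLURM variant takes a partition, a job name, and a total GPU count before the run command, e.g.: GPUS_PER_NODE=8 ./tools/run_dist_slurm.sh <partition> <job name> 16 python main.py <training args>. srun then starts one task per GPU, and the SLURM branch of init_distributed_mode in hotr/util/misc.py derives the ranks from SLURM_PROCID and related variables.
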
v-coco ADDED
@@ -0,0 +1 @@
1
+ /data/public/rw/datasets/v-coco/
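
Note: v-coco here (like hico_20160224_det earlier in the file list) is a symbolic link; git stores a symlink as a one-line blob holding the link target, so this entry simply points the repository at a local copy of the V-COCO dataset.
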