Commit 5e0b9df · 0 parent(s)
root committed: initial commit
Browse files
- .DS_Store +0 -0
- .gitignore +135 -0
- LICENSE +201 -0
- NOTICE +39 -0
- README.md +130 -0
- configs/hico_train.sh +40 -0
- configs/vcoco_train.sh +42 -0
- hico_20160224_det +1 -0
- hotr/data/datasets/__init__.py +24 -0
- hotr/data/datasets/builtin_meta.py +110 -0
- hotr/data/datasets/coco.py +156 -0
- hotr/data/datasets/hico.py +243 -0
- hotr/data/datasets/vcoco.py +467 -0
- hotr/data/evaluators/coco_eval.py +256 -0
- hotr/data/evaluators/hico_eval.py +242 -0
- hotr/data/evaluators/vcoco_eval.py +57 -0
- hotr/data/transforms/transforms.py +387 -0
- hotr/engine/__init__.py +14 -0
- hotr/engine/arg_parser.py +163 -0
- hotr/engine/evaluator_coco.py +62 -0
- hotr/engine/evaluator_hico.py +55 -0
- hotr/engine/evaluator_vcoco.py +87 -0
- hotr/engine/trainer.py +73 -0
- hotr/metrics/utils.py +90 -0
- hotr/metrics/vcoco/ap_agent.py +104 -0
- hotr/metrics/vcoco/ap_role.py +193 -0
- hotr/models/__init__.py +5 -0
- hotr/models/backbone.py +118 -0
- hotr/models/criterion.py +349 -0
- hotr/models/detr.py +187 -0
- hotr/models/detr_matcher.py +81 -0
- hotr/models/feed_forward.py +16 -0
- hotr/models/hotr.py +241 -0
- hotr/models/hotr_matcher.py +216 -0
- hotr/models/position_encoding.py +89 -0
- hotr/models/post_process.py +162 -0
- hotr/models/transformer.py +320 -0
- hotr/util/__init__.py +0 -0
- hotr/util/box_ops.py +110 -0
- hotr/util/logger.py +145 -0
- hotr/util/misc.py +401 -0
- hotr/util/ramp.py +23 -0
- imgs/mainfig.png +0 -0
- main.py +240 -0
- tools/launch.py +192 -0
- tools/run_dist_launch.sh +29 -0
- tools/run_dist_slurm.sh +33 -0
- v-coco +1 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitignore
ADDED
@@ -0,0 +1,135 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+wandb/
+checkpoints/
+
+# old version
+hotr/models/hotr_v1.py
+Makefile
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2021 KAKAO BRAIN Corp. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
NOTICE
ADDED
@@ -0,0 +1,39 @@
+===============================================================================
+DETR's Apache License 2.0
+===============================================================================
+The implementation code is based on the implementation in DETR
+(https://github.com/facebookresearch/detr).
+- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+Copyright (c) 2020 Facebook
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+===============================================================================
+QPIC's Apache License 2.0
+===============================================================================
+The implementation code is based on the implementation in QPIC
+(https://github.com/hitachi-rd-cv/qpic).
+- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+Copyright (c) 2021 Hitachi
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
README.md
ADDED
@@ -0,0 +1,130 @@
+# CPC_HOTR
+
+This repository contains the application of [Cross-Path Consistency Learning](https://arxiv.org/abs/2204.04836) to [HOTR](https://arxiv.org/abs/2104.13682), based on the official implementation of HOTR available [here](https://github.com/kakaobrain/HOTR).
+
+<div align="center">
+  <img src="imgs/mainfig.png" width="900px" />
+</div>
+
+
+## 1. Environmental Setup
+```bash
+$ conda create -n HOTR_CPC python=3.7
+$ conda install -c pytorch pytorch torchvision # PyTorch 1.7.1, torchvision 0.8.2, CUDA=11.0
+$ conda install cython scipy
+$ pip install pycocotools
+$ pip install opencv-python
+$ pip install wandb
+```
+
+## 2. HOI dataset setup
+Our current version of HOTR supports experiments on both the [V-COCO](https://github.com/s-gupta/v-coco) and [HICO-DET](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) datasets.
+Download the datasets under the cloned repository.
+For HICO-DET, we use the [annotation files](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) provided by the PPDM authors.
+Download the [list of actions](https://drive.google.com/open?id=1EeHNHuYyJI-qqDk_-5nay7Mb07tzZLsl) as `list_action.txt` and place it under the untarred hico-det directory.
+The files should be placed as shown below.
+```bash
+# V-COCO setup
+$ git clone https://github.com/s-gupta/v-coco.git
+$ cd v-coco
+$ ln -s [:COCO_DIR] coco/images # COCO_DIR contains images of train2014 & val2014
+$ python script_pick_annotations.py [:COCO_DIR]/annotations
+
+# HICO-DET setup
+$ tar -zxvf hico_20160224_det.tar.gz # move the untarred folder under the cloned repository
+
+# dataset setup
+HOTR
+ │─ v-coco
+ │   │─ data
+ │   │   │─ instances_vcoco_all_2014.json
+ │   │   :
+ │   └─ coco
+ │       │─ images
+ │       │   │─ train2014
+ │       │   │   │─ COCO_train2014_000000000009.jpg
+ │       │   │   :
+ │       │   └─ val2014
+ │       │       │─ COCO_val2014_000000000042.jpg
+ :       :       :
+ │─ hico_20160224_det
+ │   │─ list_action.txt
+ │   │─ annotations
+ │   │   │─ trainval_hico.json
+ │   │   │─ test_hico.json
+ │   │   └─ corre_hico.npy
+ :   :
+```
+
+If you wish to keep the datasets in your own directory, simply change the `--data_path` argument to the directory where you downloaded them.
+```bash
+--data_path [:your_own_directory]/[v-coco/hico_20160224_det]
+```
+
+## 3. Training
+After the preparation, you can start training with the following commands.
+
+For HICO-DET training:
+```
+GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/hico_train.sh
+```
+For V-COCO training:
+```
+GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/vcoco_train.sh
+```
+
+## 4. Evaluation
+To evaluate the main inference path P1 (x->HOI), set `--path_id` to 0.
+The augmented paths are indexed 1 to 3 (1: x->HO->I, 2: x->HI->O, 3: x->OI->H).
+
+HICO-DET
+```
+python -m torch.distributed.launch \
+    --nproc_per_node=8 \
+    --use_env main.py \
+    --batch_size 2 \
+    --HOIDet \
+    --path_id 0 \
+    --share_enc \
+    --pretrained_dec \
+    --share_dec_param \
+    --num_hoi_queries [:query_num] \
+    --object_threshold 0 \
+    # use the exact same temperature value that you used during training!
+    --temperature 0.2 \
+    --no_aux_loss \
+    --eval \
+    --dataset_file hico-det \
+    --data_path hico_20160224_det \
+    --resume checkpoints/hico_det/hico_[:query_num].pth
+```
+
+V-COCO
+```
+python -m torch.distributed.launch \
+    --nproc_per_node=8 \
+    --use_env main.py \
+    --batch_size 2 \
+    --HOIDet \
+    --path_id 0 \
+    --share_enc \
+    --share_dec_param \
+    --pretrained_dec \
+    --num_hoi_queries [:query_num] \
+    # use the exact same temperature value that you used during training!
+    --temperature 0.05 \
+    --object_threshold 0 \
+    --no_aux_loss \
+    --eval \
+    --dataset_file vcoco \
+    --data_path v-coco \
+    --resume checkpoints/vcoco/vcoco_[:query_num].pth
+```
+
+## Citation
+```
+@inproceedings{park2022consistency,
+  title={Consistency Learning via Decoding Path Augmentation for Transformers in Human Object Interaction Detection},
+  author={Park, Jihwan and Lee, SeungJun and Heo, Hwan and Choi, Hyeong Kyu and Kim, Hyunwoo J},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  year={2022}
+}
+```
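The path ids above recur throughout the configs and evaluation commands, so a tiny lookup can serve as a sanity check before launching a run. This is a sketch: the `DECODING_PATHS` table restates the list from the evaluation section, and `describe_path` is an illustrative helper, not a function from the codebase.

```python
# Decoding paths from the evaluation section; 0 is the main inference
# path, 1-3 are the augmented paths used by cross-path consistency.
DECODING_PATHS = {
    0: "x -> HOI",      # main path P1
    1: "x -> HO -> I",
    2: "x -> HI -> O",
    3: "x -> OI -> H",
}

def describe_path(path_id: int) -> str:
    """Illustrative helper (not part of the repository)."""
    if path_id not in DECODING_PATHS:
        raise ValueError(f"--path_id must be in 0..3, got {path_id}")
    return DECODING_PATHS[path_id]

assert describe_path(0) == "x -> HOI"
```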
configs/hico_train.sh
ADDED
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+
+set -x
+
+EXP_DIR=logs_run_001
+PY_ARGS=${@:1}
+
+python -u main.py \
+    --project_name CPC_HOTR_HICODET \
+    --run_name ${EXP_DIR} \
+    --HOIDet \
+    --validate \
+    --share_enc \
+    --pretrained_dec \
+    --use_consis \
+    --share_dec_param \
+    --epochs 90 \
+    --lr_drop 60 \
+    --lr 1e-4 \
+    --lr_backbone 1e-5 \
+    --ramp_up_epoch 30 \
+    --path_id 0 \
+    --num_hoi_queries 16 \
+    --set_cost_idx 20 \
+    --hoi_idx_loss_coef 1 \
+    --hoi_act_loss_coef 10 \
+    --backbone resnet50 \
+    --hoi_consistency_loss_coef 0.2 \
+    --hoi_idx_consistency_loss_coef 1 \
+    --hoi_act_consistency_loss_coef 2 \
+    --hoi_eos_coef 0.1 \
+    --temperature 0.2 \
+    --no_aux_loss \
+    --hoi_aux_loss \
+    --dataset_file hico-det \
+    --data_path hico_20160224_det \
+    --frozen_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
+    --output_dir checkpoints/hico_det/ \
+    --augpath_name [\'p2\',\'p3\',\'p4\'] \
+    ${PY_ARGS}
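The consistency losses above are ramped in over the first 30 epochs (`--ramp_up_epoch 30`); the schedule itself lives in `hotr/util/ramp.py`, which is not shown in this diff. The sketch below is only a guess at the common sigmoid ramp-up shape used in consistency training, not the repository's actual function:

```python
import math

def sigmoid_rampup(epoch: float, ramp_up_epoch: float = 30.0) -> float:
    """Assumed schedule in [0, 1]; the real one is defined in hotr/util/ramp.py."""
    if ramp_up_epoch <= 0:
        return 1.0
    # Clamp progress to [0, 1], then apply the exp(-5 * (1 - t)^2) ramp.
    t = min(max(epoch, 0.0), ramp_up_epoch) / ramp_up_epoch
    return math.exp(-5.0 * (1.0 - t) ** 2)

# e.g. weight = hoi_consistency_loss_coef * sigmoid_rampup(current_epoch)
assert abs(sigmoid_rampup(30.0) - 1.0) < 1e-9
```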
configs/vcoco_train.sh
ADDED
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+set -x
+
+EXP_DIR=logs_run_001
+PY_ARGS=${@:1}
+
+python -u main.py \
+    --project_name CPC_HOTR_VCOCO \
+    --run_name ${EXP_DIR} \
+    --HOIDet \
+    --validate \
+    --share_enc \
+    --pretrained_dec \
+    --use_consis \
+    --share_dec_param \
+    --epochs 90 \
+    --lr_drop 60 \
+    --lr 1e-4 \
+    --lr_backbone 1e-5 \
+    --ramp_up_epoch 30 \
+    --path_id 0 \
+    --num_hoi_queries 16 \
+    --set_cost_idx 10 \
+    --hoi_idx_loss_coef 1 \
+    --hoi_act_loss_coef 10 \
+    --backbone resnet50 \
+    --hoi_consistency_loss_coef 1 \
+    --hoi_idx_consistency_loss_coef 1 \
+    --hoi_act_consistency_loss_coef 10 \
+    --stop_grad_stage \
+    --hoi_eos_coef 0.1 \
+    --temperature 0.05 \
+    --no_aux_loss \
+    --hoi_aux_loss \
+    --dataset_file vcoco \
+    --data_path v-coco \
+    --frozen_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
+    --output_dir checkpoints/vcoco/ \
+    --augpath_name [\'p2\',\'p3\',\'p4\'] \
+    ${PY_ARGS}
+
hico_20160224_det
ADDED
@@ -0,0 +1 @@
+/data/public/rw/datasets/hico_20160224_det/
hotr/data/datasets/__init__.py
ADDED
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import torch.utils.data
+import torchvision
+
+from hotr.data.datasets.coco import build as build_coco
+from hotr.data.datasets.vcoco import build as build_vcoco
+from hotr.data.datasets.hico import build as build_hico
+
+def get_coco_api_from_dataset(dataset):
+    for _ in range(10): # unwrap nested wrappers (e.g. torch.utils.data.Subset)
+        if isinstance(dataset, torch.utils.data.Subset):
+            dataset = dataset.dataset
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return dataset.coco
+
+
+def build_dataset(image_set, args):
+    if args.dataset_file == 'coco':
+        return build_coco(image_set, args)
+    elif args.dataset_file == 'vcoco':
+        return build_vcoco(image_set, args)
+    elif args.dataset_file == 'hico-det':
+        return build_hico(image_set, args)
+    raise ValueError(f'dataset {args.dataset_file} not supported')
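The `for _ in range(10)` loop in `get_coco_api_from_dataset` exists to peel off dataset wrappers: a `torch.utils.data.Subset` hides the underlying `CocoDetection` object that owns the pycocotools `coco` handle, and wrappers can nest. A minimal, standalone sketch of the same unwrapping pattern (with a dummy dataset in place of `CocoDetection`):

```python
import torch
from torch.utils.data import Subset, TensorDataset

base = TensorDataset(torch.arange(10))
wrapped = Subset(Subset(base, list(range(8))), [0, 2, 4])  # nested wrappers

ds = wrapped
for _ in range(10):  # 10 is just a generous bound on the nesting depth
    if isinstance(ds, Subset):
        ds = ds.dataset

assert ds is base  # the loop recovers the dataset that owns the annotations
```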
hotr/data/datasets/builtin_meta.py
ADDED
@@ -0,0 +1,110 @@
+COCO_CATEGORIES = [
+    {"color": [], "isthing": 0, "id": 0, "name": "N/A"},
+    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+    {"color": [], "isthing": 0, "id": 12, "name": "N/A"},
+    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+    {"color": [], "isthing": 0, "id": 26, "name": "N/A"},
+    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+    {"color": [], "isthing": 0, "id": 29, "name": "N/A"},
+    {"color": [], "isthing": 0, "id": 30, "name": "N/A"},
+    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+    {"color": [], "isthing": 0, "id": 45, "name": "N/A"},
+    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+    {"color": [], "isthing": 0, "id": 66, "name": "N/A"},
+    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+    {"color": [], "isthing": 0, "id": 68, "name": "N/A"},
+    {"color": [], "isthing": 0, "id": 69, "name": "N/A"},
+    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+    {"color": [], "isthing": 0, "id": 71, "name": "N/A"},
+    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+    {"color": [], "isthing": 0, "id": 83, "name": "N/A"},
+    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+]
+
+def _get_coco_instances_meta():
+    thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
+    assert len(thing_ids) == 80, f"Length of thing ids : {len(thing_ids)}"
+
+    thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
+    thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
+    thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
+
+    coco_classes = [k["name"] for k in COCO_CATEGORIES]
+
+    return {
+        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
+        "thing_classes": thing_classes,
+        "thing_colors": thing_colors,
+        "coco_classes": coco_classes,
+    }
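`_get_coco_instances_meta` collapses COCO's non-contiguous 91-slot id space (the `N/A` holes above) onto 80 contiguous class indices. A quick check of that mapping, runnable from the repository root:

```python
from hotr.data.datasets import builtin_meta

meta = builtin_meta._get_coco_instances_meta()

# COCO ids skip values (12, 26, 29, ...), so dataset id 1 ("person")
# becomes contiguous index 0 and only 80 "thing" classes remain.
assert meta['thing_dataset_id_to_contiguous_id'][1] == 0
assert len(meta['thing_classes']) == 80
assert len(meta['coco_classes']) == 91  # full list keeps the "N/A" holes
```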
hotr/data/datasets/coco.py
ADDED
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+COCO dataset which returns image_id for evaluation.
+Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
+"""
+from pathlib import Path
+
+import torch
+import torch.utils.data
+import torchvision
+from pycocotools import mask as coco_mask
+
+import hotr.data.transforms.transforms as T
+
+class CocoDetection(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms, return_masks):
+        super(CocoDetection, self).__init__(img_folder, ann_file)
+        self._transforms = transforms
+        self.prepare = ConvertCocoPolysToMask(return_masks)
+
+    def __getitem__(self, idx):
+        img, target = super(CocoDetection, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        target = {'image_id': image_id, 'annotations': target}
+        img, target = self.prepare(img, target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+    masks = []
+    for polygons in segmentations:
+        rles = coco_mask.frPyObjects(polygons, height, width)
+        mask = coco_mask.decode(rles)
+        if len(mask.shape) < 3:
+            mask = mask[..., None]
+        mask = torch.as_tensor(mask, dtype=torch.uint8)
+        mask = mask.any(dim=2)
+        masks.append(mask)
+    if masks:
+        masks = torch.stack(masks, dim=0)
+    else:
+        masks = torch.zeros((0, height, width), dtype=torch.uint8)
+    return masks
+
+
+class ConvertCocoPolysToMask(object):
+    def __init__(self, return_masks=False):
+        self.return_masks = return_masks
+
+    def __call__(self, image, target):
+        w, h = image.size
+
+        image_id = target["image_id"]
+        image_id = torch.tensor([image_id])
+
+        anno = target["annotations"]
+
+        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        boxes[:, 2:] += boxes[:, :2] # (x, y, w, h) -> (x1, y1, x2, y2)
+        boxes[:, 0::2].clamp_(min=0, max=w)
+        boxes[:, 1::2].clamp_(min=0, max=h)
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        if self.return_masks:
+            segmentations = [obj["segmentation"] for obj in anno]
+            masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+        keypoints = None
+        if anno and "keypoints" in anno[0]:
+            keypoints = [obj["keypoints"] for obj in anno]
+            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+            num_keypoints = keypoints.shape[0]
+            if num_keypoints:
+                keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+        classes = classes[keep]
+        if self.return_masks:
+            masks = masks[keep]
+        if keypoints is not None:
+            keypoints = keypoints[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes
+        if self.return_masks:
+            target["masks"] = masks
+        target["image_id"] = image_id
+        if keypoints is not None:
+            target["keypoints"] = keypoints
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+        target["area"] = area[keep]
+        target["iscrowd"] = iscrowd[keep]
+
+        target["orig_size"] = torch.as_tensor([int(h), int(w)])
+        target["size"] = torch.as_tensor([int(h), int(w)])
+
+        return image, target
+
+
+def make_coco_transforms(image_set):
+
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+
+    if image_set == 'train':
+        return T.Compose([
+            T.RandomHorizontalFlip(),
+            T.RandomSelect(
+                T.RandomResize(scales, max_size=1333),
+                T.Compose([
+                    T.RandomResize([400, 500, 600]),
+                    T.RandomSizeCrop(384, 600),
+                    T.RandomResize(scales, max_size=1333),
+                ])
+            ),
+            normalize,
+        ])
+
+    if image_set == 'val':
+        return T.Compose([
+            T.RandomResize([800], max_size=1333),
+            normalize,
+        ])
+
+    raise ValueError(f'unknown {image_set}')
+
+
+def build(image_set, args):
+    root = Path(args.data_path)
+    assert root.exists(), f'provided COCO path {root} does not exist'
+    mode = 'instances'
+    PATHS = {
+        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
+        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
+    }
+
+    img_folder, ann_file = PATHS[image_set]
+    dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
+    return dataset
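The line `boxes[:, 2:] += boxes[:, :2]` in `ConvertCocoPolysToMask.__call__` is the standard conversion from COCO's `(x, y, w, h)` box format to the `(x1, y1, x2, y2)` corner format that the later clamping and `keep` mask expect. A standalone illustration:

```python
import torch

boxes_xywh = torch.tensor([[10., 20., 30., 40.]])  # COCO: (x, y, width, height)

boxes_xyxy = boxes_xywh.clone()
boxes_xyxy[:, 2:] += boxes_xyxy[:, :2]  # -> (x1, y1, x2, y2)

assert boxes_xyxy.tolist() == [[10., 20., 40., 60.]]
```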
hotr/data/datasets/hico.py
ADDED
@@ -0,0 +1,243 @@
+# ------------------------------------------------------------------------
+# HOTR official code : hotr/data/datasets/hico.py
+# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
+# Copyright (c) Hitachi, Ltd. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+from pathlib import Path
+from PIL import Image
+import json
+from collections import defaultdict
+import numpy as np
+
+import torch
+import torch.utils.data
+import torchvision
+
+from hotr.data.datasets import builtin_meta
+import hotr.data.transforms.transforms as T
+
+
+class HICODetection(torch.utils.data.Dataset):
+    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
+        self.img_set = img_set
+        self.img_folder = img_folder
+        with open(anno_file, 'r') as f:
+            self.annotations = json.load(f)
+        with open(action_list_file, 'r') as f:
+            self.action_lines = f.readlines()
+        self._transforms = transforms
+        self.num_queries = num_queries
+        self.get_metadata()
+
+        if img_set == 'train':
+            self.ids = []
+            for idx, img_anno in enumerate(self.annotations):
+                for hoi in img_anno['hoi_annotation']:
+                    if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
+                        break
+                else:
+                    self.ids.append(idx)
+        else:
+            self.ids = list(range(len(self.annotations)))
+
+    ############################################################################
+    # Number Method
+    ############################################################################
+    def get_metadata(self):
+        meta = builtin_meta._get_coco_instances_meta()
+        self.COCO_CLASSES = meta['coco_classes']
+        self._valid_obj_ids = [id for id in meta['thing_dataset_id_to_contiguous_id'].keys()]
+        self._valid_verb_ids, self._valid_verb_names = [], []
+        for action_line in self.action_lines[2:]:
+            act_id, act_name = action_line.split()
+            self._valid_verb_ids.append(int(act_id))
+            self._valid_verb_names.append(act_name)
+
+    def get_valid_obj_ids(self):
+        return self._valid_obj_ids
+
+    def get_actions(self):
+        return self._valid_verb_names
+
+    def num_category(self):
+        return len(self.COCO_CLASSES)
+
+    def num_action(self):
+        return len(self._valid_verb_ids)
+    ############################################################################
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, idx):
+        img_anno = self.annotations[self.ids[idx]]
+
+        img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
+        w, h = img.size
+
+        # cut out the GTs that exceed the number of object queries
+        if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
+            img_anno['annotations'] = img_anno['annotations'][:self.num_queries]
+
+        boxes = [obj['bbox'] for obj in img_anno['annotations']]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+
+        if self.img_set == 'train':
+            # Add index for confirming which boxes are kept after image transformation
+            classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])]
+        else:
+            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        target = {}
+        target['orig_size'] = torch.as_tensor([int(h), int(w)])
+        target['size'] = torch.as_tensor([int(h), int(w)])
+        if self.img_set == 'train':
+            boxes[:, 0::2].clamp_(min=0, max=w)
+            boxes[:, 1::2].clamp_(min=0, max=h)
+            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+            boxes = boxes[keep]
+            classes = classes[keep]
+
+            target['boxes'] = boxes
+            target['labels'] = classes
+            target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
+            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+            if self._transforms is not None:
+                img, target = self._transforms(img, target)
+
+            kept_box_indices = [label[0] for label in target['labels']]
+
+            target['labels'] = target['labels'][:, 1]
+
+            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
+            sub_obj_pairs = []
+            for hoi in img_anno['hoi_annotation']:
+                if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
+                    continue
+                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
+                if sub_obj_pair in sub_obj_pairs:
+                    verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
+                else:
+                    sub_obj_pairs.append(sub_obj_pair)
+                    obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
+                    verb_label = [0 for _ in range(len(self._valid_verb_ids))]
+                    verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
+                    sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
+                    obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
+                    verb_labels.append(verb_label)
+                    sub_boxes.append(sub_box)
+                    obj_boxes.append(obj_box)
+            if len(sub_obj_pairs) == 0:
+                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
+                target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
+                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
+                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
+            else:
+                target['pair_targets'] = torch.stack(obj_labels)
+                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
+                target['sub_boxes'] = torch.stack(sub_boxes)
+                target['obj_boxes'] = torch.stack(obj_boxes)
+        else:
+            target['boxes'] = boxes
+            target['labels'] = classes
+            target['id'] = idx
+
+            if self._transforms is not None:
+                img, _ = self._transforms(img, None)
+
+            hois = []
+            for hoi in img_anno['hoi_annotation']:
+                hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
+            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)
+
+        return img, target
+
+    def set_rare_hois(self, anno_file):
+        with open(anno_file, 'r') as f:
+            annotations = json.load(f)
+
+        counts = defaultdict(lambda: 0)
+        for img_anno in annotations:
+            hois = img_anno['hoi_annotation']
+            bboxes = img_anno['annotations']
+            for hoi in hois:
+                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
+                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
+                           self._valid_verb_ids.index(hoi['category_id']))
+                counts[triplet] += 1
+        self.rare_triplets = []
+        self.non_rare_triplets = []
+        for triplet, count in counts.items():
+            if count < 10:
+                self.rare_triplets.append(triplet)
+            else:
+                self.non_rare_triplets.append(triplet)
+
+    def load_correct_mat(self, path):
+        self.correct_mat = np.load(path)
+
+
+# Add color jitter to coco transforms
+def make_hico_transforms(image_set):
+
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+
+    if image_set == 'train':
+        return T.Compose([
+            T.RandomHorizontalFlip(),
+            T.ColorJitter(.4, .4, .4),
+            T.RandomSelect(
+                T.RandomResize(scales, max_size=1333),
+                T.Compose([
+                    T.RandomResize([400, 500, 600]),
+                    T.RandomSizeCrop(384, 600),
+                    T.RandomResize(scales, max_size=1333),
+                ])
+            ),
+            normalize,
+        ])
+
+    if image_set == 'val':
+        return T.Compose([
+            T.RandomResize([800], max_size=1333),
+            normalize,
+        ])
+
+    if image_set == 'test':
+        return T.Compose([
+            T.RandomResize([800], max_size=1333),
+            normalize,
+        ])
+
+    raise ValueError(f'unknown {image_set}')
+
+
+def build(image_set, args):
+    root = Path(args.data_path)
+    assert root.exists(), f'provided HOI path {root} does not exist'
+    PATHS = {
+        'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
|
| 231 |
+
'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
|
| 232 |
+
'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
|
| 233 |
+
}
|
| 234 |
+
CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
|
| 235 |
+
action_list_file = root / 'list_action.txt'
|
| 236 |
+
|
| 237 |
+
img_folder, anno_file = PATHS[image_set]
|
| 238 |
+
dataset = HICODetection(image_set, img_folder, anno_file, action_list_file, transforms=make_hico_transforms(image_set),
|
| 239 |
+
num_queries=args.num_queries)
|
| 240 |
+
if image_set == 'val' or image_set == 'test':
|
| 241 |
+
dataset.set_rare_hois(PATHS['train'][1])
|
| 242 |
+
dataset.load_correct_mat(CORRECT_MAT_PATH)
|
| 243 |
+
return dataset
|
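Editor's note: for orientation, a minimal sketch of how the `build` function above is typically driven. The `SimpleNamespace` stand-in for parsed arguments and the concrete argument values are assumptions; in the repository the arguments come from `hotr/engine/arg_parser.py` via `main.py`, and `hico_20160224_det` is the dataset symlink added in this commit.

# Hypothetical smoke test for the HICO-DET dataset builder defined above.
from types import SimpleNamespace
from hotr.data.datasets.hico import build

args = SimpleNamespace(data_path='hico_20160224_det', num_queries=100)  # assumed values
train_set = build('train', args)   # HICODetection with train-time augmentation
img, target = train_set[0]         # normalized image tensor + HOI target dict
print(len(train_set), target['pair_actions'].shape)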
hotr/data/datasets/vcoco.py
ADDED
@@ -0,0 +1,467 @@
# Copyright (c) Kakaobrain, Inc. and its affiliates. All Rights Reserved
"""
V-COCO dataset which returns image_id for evaluation.
"""
from pathlib import Path

from PIL import Image
import os
import numpy as np
import json
import torch
import torch.utils.data
import torchvision

from torch.utils.data import Dataset
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T

class VCocoDetection(Dataset):
    def __init__(self,
                 img_folder,
                 ann_file,
                 all_file,
                 filter_empty_gt=True,
                 transforms=None):
        self.img_folder = img_folder
        self.file_meta = dict()
        self._transforms = transforms

        self.ann_file = ann_file
        self.all_file = all_file
        self.filter_empty_gt = filter_empty_gt

        # COCO initialize
        self.coco = COCO(self.all_file)
        self.COCO_CLASSES = builtin_meta._get_coco_instances_meta()['coco_classes']
        self.file_meta['coco_classes'] = self.COCO_CLASSES

        # Load V-COCO Dataset
        self.vcoco_all = self.load_vcoco(self.ann_file)

        # Save COCO annotation data
        self.image_ids = sorted(list(set(self.vcoco_all[0]['image_id'].reshape(-1))))

        # Filter Data
        if filter_empty_gt:
            self.filter_image_id()
        self.img_infos = self.load_annotations()

        # Refine Data
        self.save_action_name()
        self.mapping_inst_action_to_action()
        self.load_subobj_classes()
        self.CLASSES = self.act_list

    ############################################################################
    # Load V-COCO Dataset
    ############################################################################
    def load_vcoco(self, dir_name=None):
        with open(dir_name, 'rt') as f:
            vsrl_data = json.load(f)

        for i in range(len(vsrl_data)):
            vsrl_data[i]['role_object_id'] = np.array(vsrl_data[i]['role_object_id']).reshape((len(vsrl_data[i]['role_name']),-1)).T
            for j in ['ann_id', 'label', 'image_id']:
                vsrl_data[i][j] = np.array(vsrl_data[i][j]).reshape((-1,1))

        return vsrl_data

    ############################################################################
    # Refine Data
    ############################################################################
    def save_action_name(self):
        self.inst_act_list = list()
        self.act_list = list()

        # add instance action human classes
        self.num_subject_act = 0
        for vcoco in self.vcoco_all:
            self.inst_act_list.append('human_' + vcoco['action_name'])
            self.num_subject_act += 1

        # add instance action object classes
        for vcoco in self.vcoco_all:
            if len(vcoco['role_name']) == 3:
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][1])
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][2])
            elif len(vcoco['role_name']) < 2:
                continue
            else:
                self.inst_act_list.append('object_' + vcoco['action_name']+'_'+vcoco['role_name'][-1]) # when only two roles

        # add action classes
        for vcoco in self.vcoco_all:
            if len(vcoco['role_name']) == 3:
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][1])
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][2])
            else:
                self.act_list.append(vcoco['action_name']+'_'+vcoco['role_name'][-1])

        # add to meta
        self.file_meta['action_classes'] = self.act_list

    def mapping_inst_action_to_action(self):
        sub_idx = 0
        obj_idx = self.num_subject_act

        self.sub_label_to_action = list()
        self.obj_label_to_action = list()

        for vcoco in self.vcoco_all:
            role_name = vcoco['role_name']

            self.sub_label_to_action.append(sub_idx)
            if len(role_name) == 3:
                self.sub_label_to_action.append(sub_idx)
                self.obj_label_to_action.append(obj_idx)
                self.obj_label_to_action.append(obj_idx+1)
                obj_idx += 2
            elif len(role_name) == 2:
                self.obj_label_to_action.append(obj_idx)
                obj_idx += 1
            else:
                self.obj_label_to_action.append(0)

            sub_idx += 1

    def load_subobj_classes(self):
        self.vcoco_labels = dict()
        for img in self.image_ids:
            self.vcoco_labels[img] = dict()
            self.vcoco_labels[img]['boxes'] = np.empty((0, 4), dtype=np.float32)
            self.vcoco_labels[img]['categories'] = np.empty((0), dtype=np.int32)

            ann_ids = self.coco.getAnnIds(imgIds=img, iscrowd=None)
            objs = self.coco.loadAnns(ann_ids)

            valid_ann_ids = []

            for i, obj in enumerate(objs):
                if 'ignore' in obj and obj['ignore'] == 1: continue

                x1 = obj['bbox'][0]
                y1 = obj['bbox'][1]
                x2 = x1 + np.maximum(0., obj['bbox'][2] - 1.)
                y2 = y1 + np.maximum(0., obj['bbox'][3] - 1.)

                if obj['area'] > 0 and x2 > x1 and y2 > y1:
                    bbox = np.array([x1, y1, x2, y2]).reshape(1, -1)
                    cls = obj['category_id']
                    self.vcoco_labels[img]['boxes'] = np.concatenate([self.vcoco_labels[img]['boxes'], bbox], axis=0)
                    self.vcoco_labels[img]['categories'] = np.concatenate([self.vcoco_labels[img]['categories'], [cls]], axis=0)

                    valid_ann_ids.append(ann_ids[i])

            num_valid_objs = len(valid_ann_ids)

            self.vcoco_labels[img]['agent_actions'] = -np.ones((num_valid_objs, self.num_action()), dtype=np.int32)
            self.vcoco_labels[img]['obj_actions'] = np.zeros((num_valid_objs, self.num_action()), dtype=np.int32)
            self.vcoco_labels[img]['role_id'] = -np.ones((num_valid_objs, self.num_action()), dtype=np.int32)

            for ix, ann_id in enumerate(valid_ann_ids):
                in_vcoco = np.where(self.vcoco_all[0]['ann_id'] == ann_id)[0]
                if in_vcoco.size > 0:
                    self.vcoco_labels[img]['agent_actions'][ix, :] = 0

                    agent_act_id = 0
                    obj_act_id = -1
                    for i, x in enumerate(self.vcoco_all):
                        has_label = np.where(np.logical_and(x['ann_id'] == ann_id, x['label'] == 1))[0]
                        if has_label.size > 0:
                            assert has_label.size == 1
                            rids = x['role_object_id'][has_label]

                            if rids.shape[1] == 3:
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id] = 1
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id+1] = 1
                                agent_act_id += 2
                            else:
                                self.vcoco_labels[img]['agent_actions'][ix, agent_act_id] = 1
                                agent_act_id += 1
                                if rids.shape[1] == 1: obj_act_id += 1

                            for j in range(1, rids.shape[1]):
                                obj_act_id += 1
                                if rids[0, j] == 0: continue # no role
                                aid = np.where(valid_ann_ids == rids[0, j])[0]

                                self.vcoco_labels[img]['role_id'][ix, obj_act_id] = aid
                                self.vcoco_labels[img]['obj_actions'][aid, obj_act_id] = 1

                        else:
                            rids = x['role_object_id'][0]
                            if rids.shape[0] == 3:
                                agent_act_id += 2
                                obj_act_id += 2
                            else:
                                agent_act_id += 1
                                obj_act_id += 1

    ############################################################################
    # Annotation Loader
    ############################################################################
    # >>> 1. instance
    def load_instance_annotations(self, image_index):
        num_ann = self.vcoco_labels[image_index]['boxes'].shape[0]
        inst_action = np.zeros((num_ann, self.num_inst_action()), int)  # builtin int (np.int is removed in NumPy >= 1.24)
        inst_bbox = np.zeros((num_ann, 4), dtype=np.float32)
        inst_category = np.zeros((num_ann, ), dtype=int)

        for idx in range(num_ann):
            inst_bbox[idx] = self.vcoco_labels[image_index]['boxes'][idx]
            inst_category[idx] = self.vcoco_labels[image_index]['categories'][idx] #+ 1 # category 1 ~ 81

            if inst_category[idx] == 1:
                act = self.vcoco_labels[image_index]['agent_actions'][idx]
                inst_action[idx, :self.num_subject_act] = act[np.unique(self.sub_label_to_action, return_index=True)[1]]

                # when person is the obj
                act = self.vcoco_labels[image_index]['obj_actions'][idx] # when person is the obj
                if act.any():
                    inst_action[idx, self.num_subject_act:] = act[np.nonzero(self.obj_label_to_action)[0]]
                    if inst_action[idx, :self.num_subject_act].sum(axis=-1) < 0:
                        inst_action[idx, :self.num_subject_act] = 0
            else:
                act = self.vcoco_labels[image_index]['obj_actions'][idx]
                inst_action[idx, self.num_subject_act:] = act[np.nonzero(self.obj_label_to_action)[0]]

        # >>> For Objects that are in COCO but not in V-COCO,
        # >>> Human -> [-1 * 26, 0 * 25]
        # >>> Object -> [0 * 51]
        # >>> Don't return anything for actions with max 0 or max -1
        max_val = inst_action.max(axis=1)
        if (max_val > 0).sum() == 0:
            print(f"No Annotations for {image_index}")
            print(inst_action)
            print(self.vcoco_labels[image_index]['agent_actions'][idx])
            print(self.vcoco_labels[image_index]['obj_actions'][idx])

        return inst_bbox[max_val > 0], inst_category[max_val > 0], inst_action[max_val > 0]

    # >>> 2. pair
    def load_pair_annotations(self, image_index):
        num_ann = self.vcoco_labels[image_index]['boxes'].shape[0]
        pair_action = np.zeros((0, self.num_action()), int)  # builtin int (np.int is removed in NumPy >= 1.24)
        pair_bbox = np.zeros((0, 8), dtype=np.float32)
        pair_target = np.zeros((0, ), dtype=int)

        for idx in range(num_ann):
            h_box = self.vcoco_labels[image_index]['boxes'][idx]
            h_cat = self.vcoco_labels[image_index]['categories'][idx]
            if h_cat != 1: continue # human_id = 1

            h_act = self.vcoco_labels[image_index]['agent_actions'][idx]
            if np.any((h_act == -1)): continue

            o_act = dict()
            for aid in range(self.num_action()):
                if h_act[aid] == 0: continue
                o_id = self.vcoco_labels[image_index]['role_id'][idx, aid]
                if o_id not in o_act: o_act[o_id] = list()
                o_act[o_id].append(aid)

            for o_id in o_act.keys():
                if o_id == -1:
                    o_box = -np.ones((4, ))
                    o_cat = -1 # target is background
                else:
                    o_box = self.vcoco_labels[image_index]['boxes'][o_id]
                    o_cat = self.vcoco_labels[image_index]['categories'][o_id] # category 0 ~ 80

                box = np.concatenate([h_box, o_box]).astype(np.float32)
                act = np.zeros((1, self.num_action()), int)
                tar = np.zeros((1, ), int)
                tar[0] = o_cat #+ 1 # category 1 ~ 81
                for o_aid in o_act[o_id]: act[0, o_aid] = 1

                pair_action = np.concatenate([pair_action, act], axis=0)
                pair_bbox = np.concatenate([pair_bbox, np.expand_dims(box, axis=0)], axis=0)
                pair_target = np.concatenate([pair_target, tar], axis=0)

        return pair_bbox, pair_action, pair_target

    # >>> 3. image infos
    def load_annotations(self):
        img_infos = []
        for i in self.image_ids:
            info = self.coco.loadImgs([i])[0]
            img_infos.append(info)
        return img_infos

    ############################################################################
    # Check Method
    ############################################################################
    def sum_action_ann_for_id(self, find_idx):
        sum = 0
        for action_ann in self.vcoco_all:
            img_ids = action_ann['image_id']
            img_labels = action_ann['label']

            final_inds = img_ids[img_labels == 1]

            if (find_idx in final_inds):
                sum += 1
        # sum of class-wise existence
        return (sum > 0)

    def filter_image_id(self):
        empty_gt_list = []
        for img_id in self.image_ids:
            if not self.sum_action_ann_for_id(img_id):
                empty_gt_list.append(img_id)

        for remove_id in empty_gt_list:
            rm_idx = self.image_ids.index(remove_id)
            self.image_ids.remove(remove_id)

    ############################################################################
    # Preprocessing
    ############################################################################
    def prepare_img(self, idx):
        img_info = self.img_infos[idx]
        image = Image.open(os.path.join(self.img_folder, img_info['file_name'])).convert('RGB')
        target = self.get_ann_info(idx)

        w, h = image.size
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])

        if self._transforms is not None:
            img, target = self._transforms(image, target) # "size" gets converted here

        return img, target

    ############################################################################
    # Get Method
    ############################################################################
    def __getitem__(self, idx):
        img, target = self.prepare_img(idx)
        return img, target

    def __len__(self):
        return len(self.image_ids)

    def get_human_label_idx(self):
        return self.sub_label_to_action

    def get_object_label_idx(self):
        return self.obj_label_to_action

    def get_image_ids(self):
        return self.image_ids

    def get_categories(self):
        return self.COCO_CLASSES

    def get_inst_action(self):
        return self.inst_act_list

    def get_actions(self):
        return self.act_list

    def get_human_action(self):
        return self.inst_act_list[:self.num_subject_act]

    def get_object_action(self):
        return self.inst_act_list[self.num_subject_act:]

    def get_ann_info(self, idx):
        img_idx = int(self.image_ids[idx])

        # load each annotation
        inst_bbox, inst_label, inst_actions = self.load_instance_annotations(img_idx)
        pair_bbox, pair_actions, pair_targets = self.load_pair_annotations(img_idx)

        sample = {
            'image_id': torch.tensor([img_idx]),
            'boxes': torch.as_tensor(inst_bbox, dtype=torch.float32),
            'labels': torch.tensor(inst_label, dtype=torch.int64),
            'inst_actions': torch.tensor(inst_actions, dtype=torch.int64),
            'pair_boxes': torch.as_tensor(pair_bbox, dtype=torch.float32),
            'pair_actions': torch.tensor(pair_actions, dtype=torch.int64),
            'pair_targets': torch.tensor(pair_targets, dtype=torch.int64),
        }

        return sample

    ############################################################################
    # Number Method
    ############################################################################
    def num_category(self):
        return len(self.COCO_CLASSES)

    def num_action(self):
        return len(self.act_list)

    def num_inst_action(self):
        return len(self.inst_act_list)

    def num_human_act(self):
        return len(self.inst_act_list[:self.num_subject_act])

    def num_object_act(self):
        return len(self.inst_act_list[self.num_subject_act:])

def make_hoi_transforms(image_set):
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    if image_set == 'test':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')

def build(image_set, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided V-COCO path {root} does not exist'
    PATHS = {
        "train": (root / "coco/images/train2014/", root / "data/vcoco" / 'vcoco_trainval.json'),
        "val": (root / "coco/images/val2014", root / "data/vcoco" / 'vcoco_test.json'),
        "test": (root / "coco/images/val2014", root / "data/vcoco" / 'vcoco_test.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    all_file = root / "data/instances_vcoco_all_2014.json"
    dataset = VCocoDetection(
        img_folder = img_folder,
        ann_file = ann_file,
        all_file = all_file,
        filter_empty_gt=True,
        transforms = make_hoi_transforms(image_set)
    )
    dataset.file_meta['dataset_file'] = args.dataset_file
    dataset.file_meta['image_set'] = image_set

    return dataset
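Editor's note: a matching sketch for the V-COCO builder above. The `v-coco` path is the dataset symlink added in this commit; the attribute names mirror exactly what `build` reads from `args`, but the concrete values are assumptions.

# Hypothetical usage of the V-COCO dataset builder defined above.
from types import SimpleNamespace
from hotr.data.datasets.vcoco import build

args = SimpleNamespace(data_path='v-coco', dataset_file='vcoco')  # assumed values
val_set = build('val', args)
print(val_set.num_action(), val_set.num_inst_action())  # role-action / instance-action counts
img, target = val_set[0]  # image tensor + dict with 'pair_boxes', 'pair_actions', ...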
hotr/data/evaluators/coco_eval.py
ADDED
@@ -0,0 +1,256 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
at the end of the file, as python3 can suppress prints with contextlib
"""
import os
import contextlib
import copy
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from hotr.util.misc import all_gather


class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = all_gather(img_ids)
    all_eval_imgs = all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


def evaluate(self):
    '''
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    '''
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
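Editor's note: the evaluator above follows the standard torchvision/DETR pattern. A minimal sketch of the driving loop, where `model`, `postprocess`, `coco_gt`, and `data_loader` are stand-ins for objects defined elsewhere in this commit:

# Hypothetical evaluation loop around CocoEvaluator.
evaluator = CocoEvaluator(coco_gt, iou_types=('bbox',))
for images, targets in data_loader:
    outputs = model(images)
    predictions = postprocess(outputs, targets)  # {image_id: {'boxes', 'scores', 'labels'}}
    evaluator.update(predictions)
evaluator.synchronize_between_processes()  # gather per-image results across ranks
evaluator.accumulate()
evaluator.summarize()  # prints the standard COCO AP/AR table per IoU type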
hotr/data/evaluators/hico_eval.py
ADDED
@@ -0,0 +1,242 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/data/evaluators/hico_eval.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import numpy as np
from collections import defaultdict

class HICOEvaluator():
    def __init__(self, preds, gts, rare_triplets, non_rare_triplets, correct_mat):
        self.overlap_iou = 0.5
        self.max_hois = 100

        self.rare_triplets = rare_triplets
        self.non_rare_triplets = non_rare_triplets

        self.fp = defaultdict(list)
        self.tp = defaultdict(list)
        self.score = defaultdict(list)
        self.sum_gts = defaultdict(lambda: 0)
        self.gt_triplets = []

        self.preds = []
        for img_preds in preds:
            img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items() if k != 'hoi_recognition_time'}
            bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])]
            hoi_scores = img_preds['verb_scores']
            verb_labels = np.tile(np.arange(hoi_scores.shape[1]), (hoi_scores.shape[0], 1))
            subject_ids = np.tile(img_preds['sub_ids'], (hoi_scores.shape[1], 1)).T
            object_ids = np.tile(img_preds['obj_ids'], (hoi_scores.shape[1], 1)).T

            hoi_scores = hoi_scores.ravel()
            verb_labels = verb_labels.ravel()
            subject_ids = subject_ids.ravel()
            object_ids = object_ids.ravel()

            if len(subject_ids) > 0:
                object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids])
                masks = correct_mat[verb_labels, object_labels]
                hoi_scores *= masks

                hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for
                        subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, hoi_scores)]
                hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
                hois = hois[:self.max_hois]
            else:
                hois = []

            self.preds.append({
                'predictions': bboxes,
                'hoi_prediction': hois
            })

        self.gts = []
        for img_gts in gts:
            img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id'}
            self.gts.append({
                'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])],
                'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']]
            })
            for hoi in self.gts[-1]['hoi_annotation']:
                triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'],
                           self.gts[-1]['annotations'][hoi['object_id']]['category_id'],
                           hoi['category_id'])

                if triplet not in self.gt_triplets:
                    self.gt_triplets.append(triplet)

                self.sum_gts[triplet] += 1

    def evaluate(self):
        for img_id, (img_preds, img_gts) in enumerate(zip(self.preds, self.gts)):
            print(f"Evaluating Score Matrix... : [{(img_id+1):>4}/{len(self.gts):<4}]", flush=True, end="\r")
            pred_bboxes = img_preds['predictions']
            gt_bboxes = img_gts['annotations']
            pred_hois = img_preds['hoi_prediction']
            gt_hois = img_gts['hoi_annotation']
            if len(gt_bboxes) != 0:
                bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes)
                self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps)
            else:
                for pred_hoi in pred_hois:
                    triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'],  # tuple, so it is hashable
                               pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id'])
                    if triplet not in self.gt_triplets:
                        continue
                    self.tp[triplet].append(0)
                    self.fp[triplet].append(1)
                    self.score[triplet].append(pred_hoi['score'])
        print(f"[stats] Score Matrix Generation completed!! ")
        map = self.compute_map()
        return map

    def compute_map(self):
        ap = defaultdict(lambda: 0)
        rare_ap = defaultdict(lambda: 0)
        non_rare_ap = defaultdict(lambda: 0)
        max_recall = defaultdict(lambda: 0)
        for triplet in self.gt_triplets:
            sum_gts = self.sum_gts[triplet]
            if sum_gts == 0:
                continue

            tp = np.array((self.tp[triplet]))
            fp = np.array((self.fp[triplet]))
            if len(tp) == 0:
                ap[triplet] = 0
                max_recall[triplet] = 0
                if triplet in self.rare_triplets:
                    rare_ap[triplet] = 0
                elif triplet in self.non_rare_triplets:
                    non_rare_ap[triplet] = 0
                else:
                    print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
                continue

            score = np.array(self.score[triplet])
            sort_inds = np.argsort(-score)
            fp = fp[sort_inds]
            tp = tp[sort_inds]
            fp = np.cumsum(fp)
            tp = np.cumsum(tp)
            rec = tp / sum_gts
            prec = tp / (fp + tp)
            ap[triplet] = self.voc_ap(rec, prec)
            max_recall[triplet] = np.amax(rec)
            if triplet in self.rare_triplets:
                rare_ap[triplet] = ap[triplet]
            elif triplet in self.non_rare_triplets:
                non_rare_ap[triplet] = ap[triplet]
            else:
                print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
        m_ap = np.mean(list(ap.values())) * 100 # percentage
        m_ap_rare = np.mean(list(rare_ap.values())) * 100 # percentage
        m_ap_non_rare = np.mean(list(non_rare_ap.values())) * 100 # percentage
        m_max_recall = np.mean(list(max_recall.values()))

        return {'mAP': m_ap, 'mAP rare': m_ap_rare, 'mAP non-rare': m_ap_non_rare, 'mean max recall': m_max_recall}

    def voc_ap(self, rec, prec):
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
        return ap

    def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps):
        pos_pred_ids = match_pairs.keys()
        vis_tag = np.zeros(len(gt_hois))
        pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
        if len(pred_hois) != 0:
            for pred_hoi in pred_hois:
                is_match = 0
                if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi['object_id'] in pos_pred_ids:
                    pred_sub_ids = match_pairs[pred_hoi['subject_id']]
                    pred_obj_ids = match_pairs[pred_hoi['object_id']]
                    pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']]
                    pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']]
                    pred_category_id = pred_hoi['category_id']
                    max_overlap = 0
                    max_gt_hoi = 0
                    for gt_hoi in gt_hois:
                        if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids \
                           and pred_category_id == gt_hoi['category_id']:
                            is_match = 1
                            min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])],
                                                 pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])])
                            if min_overlap_gt > max_overlap:
                                max_overlap = min_overlap_gt
                                max_gt_hoi = gt_hoi
                triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'], pred_bboxes[pred_hoi['object_id']]['category_id'],
                           pred_hoi['category_id'])
                if triplet not in self.gt_triplets:
                    continue
                if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0:
                    self.fp[triplet].append(0)
                    self.tp[triplet].append(1)
                    vis_tag[gt_hois.index(max_gt_hoi)] = 1
                else:
                    self.fp[triplet].append(1)
                    self.tp[triplet].append(0)
                self.score[triplet].append(pred_hoi['score'])

    def compute_iou_mat(self, bbox_list1, bbox_list2):
        iou_mat = np.zeros((len(bbox_list1), len(bbox_list2)))
        if len(bbox_list1) == 0 or len(bbox_list2) == 0:
            return {}, {}  # callers unpack two values, so return a pair here
        for i, bbox1 in enumerate(bbox_list1):
            for j, bbox2 in enumerate(bbox_list2):
                iou_i = self.compute_IOU(bbox1, bbox2)
                iou_mat[i, j] = iou_i

        iou_mat_ov = iou_mat.copy()
        iou_mat[iou_mat >= self.overlap_iou] = 1
        iou_mat[iou_mat < self.overlap_iou] = 0

        match_pairs = np.nonzero(iou_mat)
        match_pairs_dict = {}
        match_pair_overlaps = {}
        if iou_mat.max() > 0:
            for i, pred_id in enumerate(match_pairs[1]):
                if pred_id not in match_pairs_dict.keys():
                    match_pairs_dict[pred_id] = []
                    match_pair_overlaps[pred_id] = []
                match_pairs_dict[pred_id].append(match_pairs[0][i])
                match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i], pred_id])
        return match_pairs_dict, match_pair_overlaps

    def compute_IOU(self, bbox1, bbox2):
        if isinstance(bbox1['category_id'], str):
            bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
        if isinstance(bbox2['category_id'], str):
            bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
        if bbox1['category_id'] == bbox2['category_id']:
            rec1 = bbox1['bbox']
            rec2 = bbox2['bbox']
            # computing the area of each rectangle
            S_rec1 = (rec1[2] - rec1[0] + 1) * (rec1[3] - rec1[1] + 1)
            S_rec2 = (rec2[2] - rec2[0] + 1) * (rec2[3] - rec2[1] + 1)

            # computing the sum area
            sum_area = S_rec1 + S_rec2

            # find each edge of the intersecting rectangle
            left_line = max(rec1[1], rec2[1])
            right_line = min(rec1[3], rec2[3])
            top_line = max(rec1[0], rec2[0])
            bottom_line = min(rec1[2], rec2[2])
            # judge if there is an intersection
            if left_line >= right_line or top_line >= bottom_line:
                return 0
            else:
                intersect = (right_line - left_line + 1) * (bottom_line - top_line + 1)
                return intersect / (sum_area - intersect)
        else:
            return 0
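Editor's note: a schematic of how this evaluator is fed (the actual collection happens in `hotr/engine/evaluator_hico.py`). `preds` holds one dict of tensors per image with the keys read in `__init__` above, `gts` holds the matching targets, and the rare/non-rare split and `correct_mat` come from `HICODetection`; `model`, `postprocessor`, and `test_loader` are stand-ins.

# Hypothetical end-to-end use of HICOEvaluator.
preds, gts = [], []
for images, targets in test_loader:
    outputs = model(images)
    preds.extend(postprocessor(outputs, targets))  # dicts with boxes/labels/verb_scores/sub_ids/obj_ids
    gts.extend(targets)
evaluator = HICOEvaluator(preds, gts, dataset.rare_triplets,
                          dataset.non_rare_triplets, dataset.correct_mat)
stats = evaluator.evaluate()  # {'mAP', 'mAP rare', 'mAP non-rare', 'mean max recall'}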
hotr/data/evaluators/vcoco_eval.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) KakaoBrain, Inc. and its affiliates. All Rights Reserved
"""
V-COCO evaluator that works in distributed mode.
"""
import os
import numpy as np
import torch

from hotr.util.misc import all_gather
from hotr.metrics.vcoco.ap_role import APRole
from functools import partial

def init_vcoco_evaluators(human_act_name, object_act_name):
    role_eval1 = APRole(act_name=object_act_name, scenario_flag=True, iou_threshold=0.5)
    role_eval2 = APRole(act_name=object_act_name, scenario_flag=False, iou_threshold=0.5)

    return role_eval1, role_eval2

class VCocoEvaluator(object):
    def __init__(self, args):
        self.img_ids = []
        self.eval_imgs = []
        self.role_eval1, self.role_eval2 = init_vcoco_evaluators(args.human_actions, args.object_actions)
        self.num_human_act = args.num_human_act
        self.action_idx = args.valid_ids

    def update(self, outputs):
        img_ids = list(np.unique(list(outputs.keys())))
        for img_num, img_id in enumerate(img_ids):
            print(f"Evaluating Score Matrix... : [{(img_num+1):>4}/{len(img_ids):<4}]", flush=True, end="\r")
            prediction = outputs[img_id]['prediction']
            target = outputs[img_id]['target']

            # score with prediction
            hbox, hcat, obox, ocat = list(map(lambda x: prediction[x], \
                ['h_box', 'h_cat', 'o_box', 'o_cat']))

            assert 'pair_score' in prediction
            score = prediction['pair_score']

            hbox, hcat, obox, ocat, score = \
                list(map(lambda x: x.cpu().numpy(), [hbox, hcat, obox, ocat, score]))

            # ground-truth
            gt_h_inds = (target['labels'] == 1)
            gt_h_box = target['boxes'][gt_h_inds, :4].cpu().numpy()
            gt_h_act = target['inst_actions'][gt_h_inds, :self.num_human_act].cpu().numpy()

            gt_p_box = target['pair_boxes'].cpu().numpy()
            gt_p_act = target['pair_actions'].cpu().numpy()

            score = score[self.action_idx, :, :]
            gt_p_act = gt_p_act[:, self.action_idx]

            self.role_eval1.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act)
            self.role_eval2.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act)
            self.img_ids.append(img_id)
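Editor's note: `update` expects `outputs` keyed by image id, each entry holding a `prediction` dict (with `h_box`, `o_box`, `pair_score`, ...) and the matching `target`. A sketch of the intended flow follows; calling `evaluate()` on the two `APRole` accumulators is an assumption about the API in `hotr/metrics/vcoco/ap_role.py`, which is added elsewhere in this commit.

# Hypothetical use of VCocoEvaluator; `outputs` comes from the V-COCO engine loop.
evaluator = VCocoEvaluator(args)
evaluator.update(outputs)  # accumulates into role_eval1 (scenario 1) and role_eval2 (scenario 2)
scenario1_ap = evaluator.role_eval1.evaluate()  # assumed APRole API
scenario2_ap = evaluator.role_eval2.evaluate()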
hotr/data/transforms/transforms.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

from hotr.util.box_ops import box_xyxy_to_cxcywh
from hotr.util.misc import interpolate


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])
    max_size = torch.as_tensor([w, h], dtype=torch.float32)

    fields = ["labels", "area", "iscrowd"]  # add additional fields
    if "inst_actions" in target.keys():
        fields.append("inst_actions")

    if "boxes" in target:
        boxes = target["boxes"]
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "pair_boxes" in target or ("sub_boxes" in target and "obj_boxes" in target):
        if "pair_boxes" in target:
            pair_boxes = target["pair_boxes"]
            hboxes = pair_boxes[:, :4]
            oboxes = pair_boxes[:, 4:]
        if ("sub_boxes" in target and "obj_boxes" in target):
            hboxes = target["sub_boxes"]
            oboxes = target["obj_boxes"]

        cropped_hboxes = hboxes - torch.as_tensor([j, i, j, i])
        cropped_hboxes = torch.min(cropped_hboxes.reshape(-1, 2, 2), max_size)
        cropped_hboxes = cropped_hboxes.clamp(min=0)
        hboxes = cropped_hboxes.reshape(-1, 4)

        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            cropped_oboxes = oboxes[obj_mask] - torch.as_tensor([j, i, j, i])
            cropped_oboxes = torch.min(cropped_oboxes.reshape(-1, 2, 2), max_size)
            cropped_oboxes = cropped_oboxes.clamp(min=0)
            oboxes[obj_mask] = cropped_oboxes.reshape(-1, 4)
        else:
            cropped_oboxes = oboxes

        cropped_pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
        target["pair_boxes"] = cropped_pair_boxes
        pair_fields = ["pair_boxes", "pair_actions", "pair_targets"]

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            if field in target:  # added because there is no 'iscrowd' field in the v-coco dataset
                target[field] = target[field][keep]

    # remove elements that have redundant area
    if "boxes" in target and "labels" in target:
        cropped_boxes = target['boxes']
        cropped_labels = target['labels']

        cnr, keep_idx = [], []
        for idx, (cropped_box, cropped_lbl) in enumerate(zip(cropped_boxes, cropped_labels)):
            if str((cropped_box, cropped_lbl)) not in cnr:
                cnr.append(str((cropped_box, cropped_lbl)))
                keep_idx.append(True)
            else: keep_idx.append(False)

        for field in fields:
            if field in target:
                target[field] = target[field][keep_idx]

    # remove elements for which pair boxes have zero area
    if "pair_boxes" in target:
        cropped_hboxes = target["pair_boxes"][:, :4].reshape(-1, 2, 2)
        cropped_oboxes = target["pair_boxes"][:, 4:].reshape(-1, 2, 2)
        keep_h = torch.all(cropped_hboxes[:, 1, :] > cropped_hboxes[:, 0, :], dim=1)
        keep_o = torch.all(cropped_oboxes[:, 1, :] > cropped_oboxes[:, 0, :], dim=1)
        not_empty_o = torch.all(target["pair_boxes"][:, 4:] >= 0, dim=1)
        discard_o = (~keep_o) & not_empty_o
        if (discard_o).sum() > 0:
            target["pair_boxes"][discard_o, 4:] = -1

        for pair_field in pair_fields:
            target[pair_field] = target[pair_field][keep_h]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "pair_boxes" in target:
        pair_boxes = target["pair_boxes"]
        hboxes = pair_boxes[:, :4]
        oboxes = pair_boxes[:, 4:]

        # human flip
        hboxes = hboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])

        # object flip
        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            o_tmp = oboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
            oboxes[obj_mask] = o_tmp[obj_mask]

        pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
        target["pair_boxes"] = pair_boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "pair_boxes" in target:
        hboxes = target["pair_boxes"][:, :4]
        scaled_hboxes = hboxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        hboxes = scaled_hboxes

        oboxes = target["pair_boxes"][:, 4:]
        obj_mask = (oboxes[:, 0] != -1)
        if obj_mask.sum() != 0:
            scaled_oboxes = oboxes[obj_mask] * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
            oboxes[obj_mask] = scaled_oboxes

        target["pair_boxes"] = torch.cat([hboxes, oboxes], dim=-1)

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)


class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        image_width, image_height = img.size
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.))
        crop_left = int(round((image_width - crop_width) / 2.))
        return crop(img, target, (crop_top, crop_left, crop_height, crop_width))


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target


class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))


class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)


class ToTensor(object):
    def __call__(self, img, target):
        return F.to_tensor(img), target


class RandomErasing(object):

    def __init__(self, *args, **kwargs):
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        return self.eraser(img), target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        target = target.copy()
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = box_xyxy_to_cxcywh(boxes)
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes

        if "pair_boxes" in target:
            hboxes = target["pair_boxes"][:, :4]
            hboxes = box_xyxy_to_cxcywh(hboxes)
            hboxes = hboxes / torch.tensor([w, h, w, h], dtype=torch.float32)

            oboxes = target["pair_boxes"][:, 4:]
            obj_mask = (oboxes[:, 0] != -1)
            if obj_mask.sum() != 0:
                oboxes[obj_mask] = box_xyxy_to_cxcywh(oboxes[obj_mask])
                oboxes[obj_mask] = oboxes[obj_mask] / torch.tensor([w, h, w, h], dtype=torch.float32)

            pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
            target["pair_boxes"] = pair_boxes

        return image, target


class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, img, target):
        return self.color_jitter(img), target


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
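These transforms follow DETR's paired (image, target) convention, so they compose directly. A minimal sketch of building a pipeline with them (the `make_vcoco_transforms` name and the concrete size lists are illustrative, not taken from this commit):

# illustrative sketch, not part of the commit
import hotr.data.transforms.transforms as T

def make_vcoco_transforms(image_set):
    # ImageNet statistics, as used by the DETR family of models
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomResize(scales, max_size=1333),
            normalize,
        ])
    return T.Compose([
        T.RandomResize([800], max_size=1333),
        normalize,
    ])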
hotr/engine/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .evaluator_vcoco import vcoco_evaluate, vcoco_accumulate
from .evaluator_hico import hico_evaluate

def hoi_evaluator(args, model, criterion, postprocessors, data_loader, device, thr=0):
    if args.dataset_file == 'vcoco':
        return vcoco_evaluate(model, criterion, postprocessors, data_loader, device, args.output_dir, thr, args=args)
    elif args.dataset_file == 'hico-det':
        return hico_evaluate(model, postprocessors, data_loader, device, thr, args=args)
    else: raise NotImplementedError

def hoi_accumulator(args, total_res, print_results=False, wandb=False):
    if args.dataset_file == 'vcoco':
        return vcoco_accumulate(total_res, args, print_results, wandb)
    else: raise NotImplementedError
hotr/engine/arg_parser.py
ADDED
@@ -0,0 +1,163 @@
# ------------------------------------------------------------------------
# HOTR official code : engine/arg_parser.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# Modified arguments are represented with *
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import hotr.util.misc as utils

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lr_drop', default=80, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # DETR Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    parser.add_argument('--pretrain_interaction_tf', type=str, default=None,
                        help="Path to pretrained weights for the interaction transformer")

    # DETR Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # DETR Transformer (= Encoder, Instance Decoder)
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    parser.add_argument('--decoder_form', default=2, type=int,
                        help="1-decoder or 2-decoder")
    # Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss Option
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")

    # Loss coefficients (DETR)
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # Matcher (DETR)
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")

    # * HOI Detection
    parser.add_argument('--HOIDet', action='store_true',
                        help="Train HOI Detection head if the flag is provided")
    parser.add_argument('--share_enc', action='store_true',
                        help="Share the Encoder in DETR for HOI Detection if the flag is provided")
    parser.add_argument('--pretrained_dec', action='store_true',
                        help="Use Pre-trained Decoder in DETR for Interaction Decoder if the flag is provided")
    parser.add_argument('--hoi_enc_layers', default=1, type=int,
                        help="Number of encoding layers in the HOI transformer")
    parser.add_argument('--hoi_dec_layers', default=1, type=int,
                        help="Number of decoding layers in the HOI transformer")
    parser.add_argument('--hoi_nheads', default=8, type=int,
                        help="Number of attention heads in the HOI transformer")
    parser.add_argument('--hoi_dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the HOI transformer")
    # parser.add_argument('--hoi_mode', type=str, default=None, help='[inst | pair | all]')
    parser.add_argument('--num_hoi_queries', default=100, type=int,
                        help="Number of Queries for Interaction Decoder")
    parser.add_argument('--hoi_aux_loss', action='store_true')


    # * HOTR Matcher
    parser.add_argument('--set_cost_idx', default=1, type=float,
                        help="IDX coefficient in the matching cost")
    parser.add_argument('--set_cost_act', default=1, type=float,
                        help="Action coefficient in the matching cost")
    parser.add_argument('--set_cost_tgt', default=1, type=float,
                        help="Target coefficient in the matching cost")

    # * HOTR Loss coefficients
    parser.add_argument('--temperature', default=0.05, type=float, help="temperature")
    parser.add_argument('--hoi_consistency_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_idx_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_idx_consistency_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_act_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_act_consistency_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_tgt_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_tgt_consistency_loss_coef', default=1, type=float)
    parser.add_argument('--hoi_eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class")

    parser.add_argument('--ramp_down_epoch', default=10000, type=int)
    parser.add_argument('--ramp_up_epoch', default=0, type=int)
    # consistency
    parser.add_argument('--use_consis', action='store_true', help='use consistency regularization')
    parser.add_argument('--share_dec_param', action='store_true', help='share decoder parameters of all stages')
    parser.add_argument('--augpath_name', type=utils.arg_as_list, default=[],
                        help='choose which augmented inference paths to use. (p2:x->HO->I, p3:x->HI->O, p4:x->OI->H)')
    parser.add_argument('--stop_grad_stage', action='store_true', help='Do not back-propagate loss to the previous stage')
    parser.add_argument('--path_id', default=0, type=int)

    parser.add_argument('--sep_enc_forward', action='store_true')

    # * dataset parameters
    parser.add_argument('--dataset_file', help='[coco | vcoco]')
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--object_threshold', type=float, default=0, help='Threshold for object confidence')

    # machine parameters
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--custom_path', default='',
                        help="Data path for custom inference. Only required for custom_main.py")
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=2, type=int)

    # mode
    parser.add_argument('--eval', action='store_true', help="Only evaluate results if the flag is provided")
    parser.add_argument('--validate', action='store_true', help="Validate after every epoch")

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # * WanDB
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--project_name', default='hotr_cpc')
    parser.add_argument('--group_name', default='mlv')
    parser.add_argument('--run_name', default='run_000001')
    return parser
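Since the parser is built with `add_help=False`, a driver script is expected to mount it as a parent parser. A minimal sketch of consuming it (the wrapper below is illustrative; the actual wiring lives in main.py):

# illustrative sketch, not part of the commit
import argparse
from hotr.engine.arg_parser import get_args_parser

parser = argparse.ArgumentParser('HOTR training and evaluation script',
                                 parents=[get_args_parser()])
args = parser.parse_args(['--dataset_file', 'vcoco',
                          '--data_path', 'v-coco',
                          '--HOIDet', '--hoi_aux_loss'])
print(args.num_hoi_queries, args.dataset_file)  # 100 vcoco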
hotr/engine/evaluator_coco.py
ADDED
@@ -0,0 +1,62 @@
import os
import torch
import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.data.evaluators.coco_eval import CocoEvaluator

@torch.no_grad()
def coco_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()

    metric_logger = loggers.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Evaluation'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    print_freq = len(data_loader)
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    print("\n>>> [MS-COCO Evaluation] <<<")
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("\n>>> [Averaged stats] <<<\n", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()

    return stats, coco_evaluator
hotr/engine/evaluator_hico.py
ADDED
@@ -0,0 +1,55 @@
import math
import os
import sys
from typing import Iterable
import numpy as np
import copy
import itertools

import torch

import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.data.evaluators.hico_eval import HICOEvaluator

@torch.no_grad()
def hico_evaluate(model, postprocessors, data_loader, device, thr, args=None):
    model.eval()

    metric_logger = loggers.MetricLogger(mode="test", delimiter="  ")
    header = 'Evaluation Inference (HICO-DET)'

    preds = []
    gts = []
    indices = []
    hoi_recognition_time = []

    for samples, targets in metric_logger.log_every(data_loader, 50, header):
        samples = samples.to(device)
        targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det', args=args)
        hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)

        preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
        # For avoiding a runtime error, the copy is used
        gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))

    print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    img_ids = [img_gts['id'] for img_gts in gts]
    _, indices = np.unique(img_ids, return_index=True)
    preds = [img_preds for i, img_preds in enumerate(preds) if i in indices]
    gts = [img_gts for i, img_gts in enumerate(gts) if i in indices]

    evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
                              data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)

    stats = evaluator.evaluate()

    return stats
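The de-duplication above relies on `np.unique(..., return_index=True)` returning the index of the first occurrence of each image id, so duplicates produced by the distributed all-gather are dropped. A toy illustration:

# illustrative sketch, not part of the commit
import numpy as np

img_ids = [7, 3, 7, 5, 3]                  # gathered ids may repeat across processes
_, indices = np.unique(img_ids, return_index=True)
print(indices)                             # [1 3 0] -> first occurrence of ids 3, 5, 7
keep = [x for i, x in enumerate(img_ids) if i in indices]
print(keep)                                # [7, 3, 5]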
hotr/engine/evaluator_vcoco.py
ADDED
@@ -0,0 +1,87 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/engine/evaluator_vcoco.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import os
import torch
import time
import datetime

import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.data.evaluators.vcoco_eval import VCocoEvaluator
from hotr.util.box_ops import rescale_bboxes, rescale_pairs

import wandb

@torch.no_grad()
def vcoco_evaluate(model, criterion, postprocessors, data_loader, device, output_dir, thr, args=None):
    model.eval()
    criterion.eval()

    metric_logger = loggers.MetricLogger(mode="test", delimiter="  ")
    header = 'Evaluation Inference (V-COCO)'

    print_freq = 1  # len(data_loader)
    res = {}
    hoi_recognition_time = []

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        loss_dict_reduced = utils.reduce_dict(loss_dict)  # ddp gathering

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='vcoco', args=args)
        targets = process_target(targets, orig_target_sizes)
        hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)

        res.update(
            {target['image_id'].item():
                {'target': target, 'prediction': output} for target, output in zip(targets, results)
            }
        )
    print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")

    start_time = time.time()
    gather_res = utils.all_gather(res)
    total_res = {}
    for dist_res in gather_res:
        total_res.update(dist_res)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"[stats] Distributed Gathering Time : {total_time_str}")

    return total_res

def vcoco_accumulate(total_res, args, print_results, wandb_log):
    vcoco_evaluator = VCocoEvaluator(args)
    vcoco_evaluator.update(total_res)
    print(f"[stats] Score Matrix Generation completed!!")

    scenario1 = vcoco_evaluator.role_eval1.evaluate(print_results)
    scenario2 = vcoco_evaluator.role_eval2.evaluate(print_results)

    if wandb_log:
        wandb.log({
            'scenario1': scenario1,
            'scenario2': scenario2
        })

    return scenario1, scenario2

def process_target(targets, target_sizes):
    for idx, (target, target_size) in enumerate(zip(targets, target_sizes)):
        labels = target['labels']
        valid_boxes_inds = (labels > 0)

        targets[idx]['boxes'] = rescale_bboxes(target['boxes'], target_size)            # boxes
        targets[idx]['pair_boxes'] = rescale_pairs(target['pair_boxes'], target_size)   # pairs

    return targets
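A typical call pattern for these two stages, sketched from how main.py presumably wires them together (the variables on the left are illustrative; a real model, criterion, postprocessors, and data loader are required):

# illustrative sketch, not part of the commit
total_res = vcoco_evaluate(model, criterion, postprocessors, data_loader,
                           device, output_dir='', thr=0, args=args)
scenario1, scenario2 = vcoco_accumulate(total_res, args, print_results=True, wandb_log=False)
print(f"Scenario 1 mAP: {scenario1:.2f} | Scenario 2 mAP: {scenario2:.2f}")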
hotr/engine/trainer.py
ADDED
@@ -0,0 +1,73 @@
# ------------------------------------------------------------------------
# HOTR official code : engine/trainer.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import math
import torch
import sys
import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.util.ramp import *
from typing import Iterable
import wandb

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_epoch: int,
                    ramp_up_epoch: int, rampdown_epoch: int, max_consis_coef: float = 1.0,
                    max_norm: float = 0, dataset_file: str = 'coco', log: bool = False):
    model.train()
    criterion.train()
    metric_logger = loggers.MetricLogger(mode="train", delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    space_fmt = str(len(str(max_epoch)))
    header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch+1, end_epoch=max_epoch, fill=space_fmt)
    print_freq = int(len(data_loader)/5)

    # ramp the consistency-loss coefficient up, then down
    if epoch <= rampdown_epoch:
        consis_coef = sigmoid_rampup(epoch, ramp_up_epoch, max_consis_coef)
    else:
        consis_coef = cosine_rampdown(epoch-rampdown_epoch, max_epoch-rampdown_epoch, max_consis_coef)
    print(consis_coef)
    print(f"\n>>> Epoch #{(epoch+1)}")
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets, log)
        # print(loss_dict)
        weight_dict = criterion.weight_dict

        losses = sum(loss_dict[k] * weight_dict[k] * consis_coef if 'consistency' in k else loss_dict[k] * weight_dict[k]
                     for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k] * consis_coef if 'consistency' in k else v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
        if "obj_class_error" in loss_dict:
            metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    if utils.get_rank() == 0 and log: wandb.log(loss_dict_reduced_scaled)
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
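The consistency-loss weight follows a sigmoid ramp-up until `rampdown_epoch`, then a cosine ramp-down. The commit ships the ramp helpers in hotr/util/ramp.py, which is not shown in this section; the sketch below uses the standard formulations (assumed, not taken from that file), with signatures matching the calls above:

# illustrative sketch, not part of the commit
import numpy as np

def sigmoid_rampup(current, rampup_length, max_coef=1.0):
    # exp(-5 * (1 - t)^2) ramp, as popularized by temporal-ensembling work
    if rampup_length == 0:
        return max_coef
    t = np.clip(current / rampup_length, 0.0, 1.0)
    return max_coef * float(np.exp(-5.0 * (1.0 - t) ** 2))

def cosine_rampdown(current, rampdown_length, max_coef=1.0):
    # half-cosine decay from max_coef down to 0
    t = np.clip(current / rampdown_length, 0.0, 1.0)
    return max_coef * float(0.5 * (np.cos(np.pi * t) + 1.0))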
hotr/metrics/utils.py
ADDED
@@ -0,0 +1,90 @@
import torch
import numpy as np

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def compute_overlap(a, b):
    if type(a) == torch.Tensor:
        if len(a.shape) == 2:
            area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)

            iw = torch.min(a[:, 2].unsqueeze(dim=1), b[:, 2]) - torch.max(a[:, 0].unsqueeze(dim=1), b[:, 0])
            ih = torch.min(a[:, 3].unsqueeze(dim=1), b[:, 3]) - torch.max(a[:, 1].unsqueeze(dim=1), b[:, 1])

            iw[iw < 0] = 0
            ih[ih < 0] = 0

            ua = torch.unsqueeze((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), dim=1) + area - iw * ih
            ua[ua < 1e-8] = 1e-8

            intersection = iw * ih

            return intersection / ua

    elif type(a) == np.ndarray:
        if len(a.shape) == 2:
            area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0)  # (1, K)

            iw = np.minimum(np.expand_dims(a[:, 2], axis=1), np.expand_dims(b[:, 2], axis=0)) \
                 - np.maximum(np.expand_dims(a[:, 0], axis=1), np.expand_dims(b[:, 0], axis=0)) \
                 + 1
            ih = np.minimum(np.expand_dims(a[:, 3], axis=1), np.expand_dims(b[:, 3], axis=0)) \
                 - np.maximum(np.expand_dims(a[:, 1], axis=1), np.expand_dims(b[:, 1], axis=0)) \
                 + 1

            iw[iw < 0] = 0  # (N, K)
            ih[ih < 0] = 0  # (N, K)

            intersection = iw * ih

            ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - intersection
            ua[ua < 1e-8] = 1e-8

            return intersection / ua

        elif len(a.shape) == 1:
            area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0)  # (1, K)

            iw = np.minimum(np.expand_dims([a[2]], axis=1), np.expand_dims(b[:, 2], axis=0)) \
                 - np.maximum(np.expand_dims([a[0]], axis=1), np.expand_dims(b[:, 0], axis=0))
            ih = np.minimum(np.expand_dims([a[3]], axis=1), np.expand_dims(b[:, 3], axis=0)) \
                 - np.maximum(np.expand_dims([a[1]], axis=1), np.expand_dims(b[:, 1], axis=0))

            iw[iw < 0] = 0  # (N, K)
            ih[ih < 0] = 0  # (N, K)

            ua = np.expand_dims([(a[2] - a[0] + 1) * (a[3] - a[1] + 1)], axis=1) + area - iw * ih
            ua[ua < 1e-8] = 1e-8

            intersection = iw * ih

            return intersection / ua


def _compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Code originally from https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall: The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], recall, [1.]))
    mpre = np.concatenate(([0.], precision, [0.]))

    # compute the precision envelope
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap
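A quick sanity check for `_compute_ap` on a toy curve: three detections sorted by score, a true positive, a false positive, then another true positive, against two annotations. The precision envelope keeps the best precision at each recall level, so the area works out to 0.5 * 1.0 + 0.5 * (2/3):

# illustrative sketch, not part of the commit
import numpy as np
from hotr.metrics.utils import _compute_ap

recall    = np.array([0.5, 0.5, 1.0])   # cumulative TP / #annotations
precision = np.array([1.0, 0.5, 2/3])   # cumulative TP / (TP + FP)
print(_compute_ap(recall, precision))   # 0.8333...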
hotr/metrics/vcoco/ap_agent.py
ADDED
@@ -0,0 +1,104 @@
import numpy as np
from hotr.metrics.utils import _compute_ap, compute_overlap
import pdb

class APAgent(object):
    def __init__(self, act_name, iou_threshold=0.5):
        self.act_name = act_name
        self.iou_threshold = iou_threshold

        self.fp = [np.zeros((0,))] * len(act_name)
        self.tp = [np.zeros((0,))] * len(act_name)
        self.score = [np.zeros((0,))] * len(act_name)
        self.num_ann = [0] * len(act_name)

    def add_data(self, box, act, cat, i_box, i_act):
        for label in range(len(self.act_name)):
            i_inds = (i_act[:, label] == 1)
            self.num_ann[label] += i_inds.sum()

        n_pred = box.shape[0]
        if n_pred == 0: return

        ######################
        valid_i_inds = (i_act[:, 0] != -1)  # (n_i, ) # both in COCO & V-COCO

        overlaps = compute_overlap(box, i_box)        # (n_pred, n_i)
        assigned_input = np.argmax(overlaps, axis=1)  # (n_pred, )
        v_inds = valid_i_inds[assigned_input]         # (n_pred, )

        n_valid = v_inds.sum()

        if n_valid == 0: return
        valid_box = box[v_inds]
        valid_act = act[v_inds]
        valid_cat = cat[v_inds]

        ######################
        s = valid_act * np.expand_dims(valid_cat, axis=1)  # (n_v, #act)

        for label in range(len(self.act_name)):
            inds = np.argsort(s[:, label])[::-1]  # (n_v, )
            self.score[label] = np.append(self.score[label], s[inds, label])

            correct_i_inds = (i_act[:, label] == 1)
            if correct_i_inds.sum() == 0:
                self.tp[label] = np.append(self.tp[label], np.array([0]*n_valid))
                self.fp[label] = np.append(self.fp[label], np.array([1]*n_valid))
                continue

            overlaps = compute_overlap(valid_box[inds], i_box)      # (n_v, n_i)
            assigned_input = np.argmax(overlaps, axis=1)            # (n_v, )
            max_overlap = overlaps[range(n_valid), assigned_input]  # (n_v, )

            iou_inds = (max_overlap > self.iou_threshold) & correct_i_inds[assigned_input]  # (n_v, )

            i_nonzero = iou_inds.nonzero()[0]
            i_inds = assigned_input[i_nonzero]
            i_iou = np.unique(i_inds, return_index=True)[1]
            i_tp = i_nonzero[i_iou]

            t = np.zeros(n_valid, dtype=np.uint8)
            t[i_tp] = 1
            f = 1 - t

            self.tp[label] = np.append(self.tp[label], t)
            self.fp[label] = np.append(self.fp[label], f)

    def evaluate(self):
        average_precisions = dict()
        for label in range(len(self.act_name)):
            if self.num_ann[label] == 0:
                average_precisions[label] = 0
                continue

            # sort by score
            indices = np.argsort(-self.score[label])
            self.fp[label] = self.fp[label][indices]
            self.tp[label] = self.tp[label][indices]

            # compute false positives and true positives
            self.fp[label] = np.cumsum(self.fp[label])
            self.tp[label] = np.cumsum(self.tp[label])

            # compute recall and precision
            recall = self.tp[label] / self.num_ann[label]
            precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)

            # compute average precision
            average_precisions[label] = _compute_ap(recall, precision) * 100

        print('\n================== AP (Agent) ===================')
        s, n = 0, 0

        for label in range(len(self.act_name)):
            label_name = "_".join(self.act_name[label].split("_")[1:])
            print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
            s += average_precisions[label]
            n += 1

        mAP = s/n
        print('| mAP(agent): {:0.2f}'.format(mAP))
        print('----------------------------------------------------')

        return mAP
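A minimal sketch of driving `APAgent` with toy arrays. The shapes follow the indexing in `add_data` (per-image predicted boxes, per-box action scores, detection confidences, plus ground-truth instance boxes and per-instance action labels); all names and values here are made up for illustration:

# illustrative sketch, not part of the commit
import numpy as np
from hotr.metrics.vcoco.ap_agent import APAgent

agent_ap = APAgent(['agent_hold', 'agent_sit'], iou_threshold=0.5)

box   = np.array([[10, 10, 50, 50]])   # one predicted human box (x1, y1, x2, y2)
act   = np.array([[0.9, 0.1]])         # per-action scores for that box
cat   = np.array([0.8])                # detection confidence
i_box = np.array([[12, 12, 48, 48]])   # ground-truth instance box
i_act = np.array([[1, 0]])             # the instance performs action 0 only

agent_ap.add_data(box, act, cat, i_box, i_act)
mAP = agent_ap.evaluate()              # prints per-action AP and mAP(agent)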
hotr/metrics/vcoco/ap_role.py
ADDED
@@ -0,0 +1,193 @@
import numpy as np
import torch
from hotr.metrics.utils import _compute_ap, compute_overlap

class APRole(object):
    def __init__(self, act_name, scenario_flag=True, iou_threshold=0.5):
        self.act_name = act_name
        self.iou_threshold = iou_threshold

        self.scenario_flag = scenario_flag
        # scenario_1 : True
        # scenario_2 : False

        self.fp = [np.zeros((0,))] * len(act_name)
        self.tp = [np.zeros((0,))] * len(act_name)
        self.score = [np.zeros((0,))] * len(act_name)
        self.num_ann = [0] * len(act_name)

    def add_data(self, h_box, o_box, score, i_box, i_act, p_box, p_act):
        # i_box, i_act : to check if only in COCO
        for label in range(len(self.act_name)):
            p_inds = (p_act[:, label] == 1)
            self.num_ann[label] += p_inds.sum()

        if h_box.shape[0] == 0: return  # if no prediction, just return
        # COCO (O), V-COCO (X) __or__ collator, no ann in image => ignore

        valid_i_inds = (i_act[:, 0] != -1)                 # (n_i, )
        overlaps = compute_overlap(h_box, i_box)           # (n_h, n_i)
        assigned_input = np.argmax(overlaps, axis=1)       # (n_h, )
        v_inds = valid_i_inds[assigned_input]              # (n_h, )

        h_box = h_box[v_inds]
        score = score[:, v_inds, :]
        if h_box.shape[0] == 0: return
        n_h = h_box.shape[0]

        valid_p_inds = (p_act[:, 0] != -1) | (p_box[:, 0] != -1)
        p_act = p_act[valid_p_inds]
        p_box = p_box[valid_p_inds]

        n_o = o_box.shape[0]
        if n_o == 0:
            # no prediction for object
            score = score.squeeze(axis=2)  # (#act, n_h)

            for label in range(len(self.act_name)):
                h_inds = np.argsort(score[label])[::-1]  # (n_h, )
                self.score[label] = np.append(self.score[label], score[label, h_inds])

                p_inds = (p_act[:, label] == 1)
                if p_inds.sum() == 0:
                    self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
                    self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
                    continue

                h_overlaps = compute_overlap(h_box[h_inds], p_box[p_inds, :4])  # (n_h, n_p)
                assigned_p = np.argmax(h_overlaps, axis=1)                      # (n_h, )
                h_max_overlap = h_overlaps[range(n_h), assigned_p]              # (n_h, )

                o_overlaps = compute_overlap(np.zeros((n_h, 4)), p_box[p_inds][assigned_p, 4:8])
                o_overlaps = np.diag(o_overlaps)  # (n_h, )

                no_role_inds = (p_box[p_inds][assigned_p, 4] == -1)  # (n_h, )
                # human (o), action (o), no object in actual image

                h_iou_inds = (h_max_overlap > self.iou_threshold)  # (n_h, )
                o_iou_inds = (o_overlaps > self.iou_threshold)     # (n_h, )

                # scenario1 is not considered (already no object)
                o_iou_inds[no_role_inds] = 1

                iou_inds = (h_iou_inds & o_iou_inds)
                p_nonzero = iou_inds.nonzero()[0]
                p_inds = assigned_p[p_nonzero]
                p_iou = np.unique(p_inds, return_index=True)[1]
                p_tp = p_nonzero[p_iou]

                t = np.zeros(n_h, dtype=np.uint8)
                t[p_tp] = 1
                f = 1 - t

                self.tp[label] = np.append(self.tp[label], t)
                self.fp[label] = np.append(self.fp[label], f)

        else:
            s_obj_argmax = np.argmax(score.reshape(-1, n_o), axis=1).reshape(-1, n_h)  # (#act, n_h)
            s_obj_max = np.max(score.reshape(-1, n_o), axis=1).reshape(-1, n_h)        # (#act, n_h)

            h_overlaps = compute_overlap(h_box, p_box[:, :4])  # (n_h, n_p)

            for label in range(len(self.act_name)):
                h_inds = np.argsort(s_obj_max[label])[::-1]  # (n_h, )
                self.score[label] = np.append(self.score[label], s_obj_max[label, h_inds])

                p_inds = (p_act[:, label] == 1)  # (n_p, )
                if p_inds.sum() == 0:
                    self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
                    self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
                    continue

                h_overlaps = compute_overlap(h_box[h_inds], p_box[:, :4])  # (n_h, n_p) # match for all hboxes
                h_max_overlap = np.max(h_overlaps, axis=1)                 # (n_h, )    # get the max overlap for hbox

                # for same human, multiple pairs exist. find the human box that has the same idx with max overlap hbox.
                h_max_temp = np.expand_dims(h_max_overlap, axis=1)
                h_over_thresh = (h_overlaps == h_max_temp)                       # (n_h, n_p)
                h_over_thresh = h_over_thresh & np.expand_dims(p_inds, axis=0)   # (n_h, n_p) # find only for current act

                h_valid = h_over_thresh.sum(axis=1) > 0  # (n_h, ) # at least one is True
                # h_valid -> if all is False, then argmax becomes 0. <- prevent
                assigned_p = np.argmax(h_over_thresh, axis=1)  # (n_h, ) # p only for current act

                o_mapping_box = o_box[s_obj_argmax[label]][h_inds]  # (n_h, ) # find where T is.
                p_mapping_box = p_box[assigned_p, 4:8]              # (n_h, 4)

                o_overlaps = compute_overlap(o_mapping_box, p_mapping_box)
                o_overlaps = np.diag(o_overlaps)  # (n_h, )
                o_overlaps.setflags(write=1)
                if (~h_valid).sum() > 0:
                    o_overlaps[~h_valid] = 0  # (n_h, )

                no_role_inds = (p_box[assigned_p, 4] == -1)  # (n_h, )
                nan_box_inds = np.all(o_mapping_box == 0, axis=1) | np.all(np.isnan(o_mapping_box), axis=1)
                no_role_inds = no_role_inds & h_valid
                nan_box_inds = nan_box_inds & h_valid

                h_iou_inds = (h_max_overlap > self.iou_threshold)  # (n_h, )
                o_iou_inds = (o_overlaps > self.iou_threshold)     # (n_h, )

                if self.scenario_flag:  # scenario_1
                    o_iou_inds[no_role_inds & nan_box_inds] = 1
                    o_iou_inds[no_role_inds & ~nan_box_inds] = 0
                else:  # scenario_2
                    o_iou_inds[no_role_inds] = 1

                iou_inds = (h_iou_inds & o_iou_inds)
                p_nonzero = iou_inds.nonzero()[0]
                p_inds = assigned_p[p_nonzero]
                p_iou = np.unique(p_inds, return_index=True)[1]
                p_tp = p_nonzero[p_iou]

                t = np.zeros(n_h, dtype=np.uint8)
                t[p_tp] = 1
                f = 1 - t

                self.tp[label] = np.append(self.tp[label], t)
                self.fp[label] = np.append(self.fp[label], f)

    def evaluate(self, print_log=False):
        average_precisions = dict()
        role_num = 1 if self.scenario_flag else 2
        for label in range(len(self.act_name)):

            # sort by score
            indices = np.argsort(-self.score[label])
            self.fp[label] = self.fp[label][indices]
            self.tp[label] = self.tp[label][indices]

            if self.num_ann[label] == 0:
                average_precisions[label] = 0
                continue

            # compute false positives and true positives
            self.fp[label] = np.cumsum(self.fp[label])
            self.tp[label] = np.cumsum(self.tp[label])

            # compute recall and precision
            recall = self.tp[label] / self.num_ann[label]
            precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)

            # compute average precision
            average_precisions[label] = _compute_ap(recall, precision) * 100

        if print_log: print(f'\n============= AP (Role scenario_{role_num}) ==============')
        s, n = 0, 0

        for label in range(len(self.act_name)):
            if 'point' in self.act_name[label]:
                continue
            label_name = "_".join(self.act_name[label].split("_")[1:])
            if print_log: print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
            if self.num_ann[label] != 0:
                s += average_precisions[label]
                n += 1

        mAP = s / n
        if print_log:
            print('| mAP(role scenario_{:d}): {:0.2f}'.format(role_num, mAP))
            print('----------------------------------------------------')

        return mAP
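For orientation, evaluate() above is the standard detection-AP recipe: sort all accumulated detections by score, take cumulative TP/FP counts, and integrate precision over recall. Below is a minimal, self-contained sketch of that bookkeeping, assuming an all-point interpolated AP (the exact interpolation rule lives in hotr.metrics.utils._compute_ap and may differ):

import numpy as np

def sketch_ap(scores, tp, num_ann):
    # sort detections by confidence, then accumulate TP/FP cumulatively,
    # mirroring the bookkeeping in APRole.evaluate
    order = np.argsort(-scores)
    tp = tp[order].astype(np.float64)
    fp = 1.0 - tp
    tp, fp = np.cumsum(tp), np.cumsum(fp)
    recall = tp / num_ann
    precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    # all-point interpolation: area under the precision envelope
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])

# three detections, two ground-truth pairs; the mid-confidence one is a miss
print(sketch_ap(np.array([0.9, 0.8, 0.3]), np.array([1, 0, 1]), num_ann=2))  # ~0.833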
hotr/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .detr import build

def build_model(args):
    return build(args)
hotr/models/backbone.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from hotr.util.misc import NestedTensor, is_main_process

from .position_encoding import build_position_encoding


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = False  # args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
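A hedged usage sketch for build_backbone. The backbone, lr_backbone, and dilation argument names are read directly above; hidden_dim and position_embedding are what DETR's build_position_encoding expects, so treat those two names as assumptions if this repo's position_encoding.py diverges:

from argparse import Namespace

import torch

from hotr.models.backbone import build_backbone
from hotr.util.misc import nested_tensor_from_tensor_list

# hidden_dim / position_embedding follow DETR's conventions (assumption)
args = Namespace(backbone='resnet50', lr_backbone=1e-5, dilation=False,
                 hidden_dim=256, position_embedding='sine')
backbone = build_backbone(args)

# one 3x480x640 image, padded into a NestedTensor together with its mask
images = nested_tensor_from_tensor_list([torch.rand(3, 480, 640)])
features, pos = backbone(images)
print(features[-1].tensors.shape)  # stride-32 C5 features, e.g. [1, 2048, 15, 20]
print(pos[-1].shape)               # positional encoding of the same spatial size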
hotr/models/criterion.py
ADDED
@@ -0,0 +1,349 @@
# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import copy
import numpy as np
import itertools
from torch import nn

from hotr.util import box_ops
from hotr.util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized)

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.eos_coef = eos_coef

        self.HOI_losses = HOI_losses
        self.HOI_matcher = HOI_matcher
        # use consistency losses only when augmented paths are enabled
        self.use_consis = args.use_consis and len(args.augpath_name) > 0
        self.num_path = 1 + len(args.augpath_name)
        if args:
            self.HOI_eos_coef = args.hoi_eos_coef
            if args.dataset_file == 'vcoco':
                self.invalid_ids = args.invalid_ids
                self.valid_ids = np.concatenate((args.valid_ids, [-1]), axis=0)  # no interaction
            elif args.dataset_file == 'hico-det':
                self.invalid_ids = []
                self.valid_ids = list(range(num_actions)) + [-1]

            # for targets
            self.num_tgt_classes = len(args.valid_obj_ids)
            tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
            tgt_empty_weight[-1] = self.HOI_eos_coef
            self.register_buffer('tgt_empty_weight', tgt_empty_weight)
            self.dataset_file = args.dataset_file

        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    #######################################################################################################################
    # * DETR Losses
    #######################################################################################################################
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses


    #######################################################################################################################
    # * HOTR Losses
    #######################################################################################################################
    # >>> HOI Losses 1 : HO Pointer
    def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes, use_consis, log=False):
        assert ('pred_hidx' in outputs and 'pred_oidx' in outputs)
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
        nu, q, hd = outputs['pred_hidx'].shape
        src_hidx = outputs['pred_hidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1).flatten(0, 1)
        src_oidx = outputs['pred_oidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1).flatten(0, 1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))

        idx = self._get_src_permutation_idx(hoi_ind)

        target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
        target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)

        # H Pointer loss
        target_classes_h = torch.cat([t["h_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_hidx_classes[idx] = target_classes_h

        # O Pointer loss
        target_classes_o = torch.cat([t["o_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_oidx_classes[idx] = target_classes_o

        loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
        loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)

        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_hidx_inputs = [F.softmax(src_hidx.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1)
                               for i, consistency_idx in enumerate(consistency_idxs)]
            src_hidx_targets = [F.softmax(src_hidx.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1)
                                for i, consistency_idx in enumerate(consistency_idxs)]
            src_oidx_inputs = [F.softmax(src_oidx.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1)
                               for i, consistency_idx in enumerate(consistency_idxs)]
            src_oidx_targets = [F.softmax(src_oidx.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1)
                                for i, consistency_idx in enumerate(consistency_idxs)]

            # symmetric KL between every matched pair of paths
            loss_h_consistency = [0.5 * (F.kl_div(src_hidx_input.log(), src_hidx_target.clone().detach(), reduction='batchmean')
                                         + F.kl_div(src_hidx_target.log(), src_hidx_input.clone().detach(), reduction='batchmean'))
                                  for src_hidx_input, src_hidx_target in zip(src_hidx_inputs, src_hidx_targets)]
            loss_o_consistency = [0.5 * (F.kl_div(src_oidx_input.log(), src_oidx_target.clone().detach(), reduction='batchmean')
                                         + F.kl_div(src_oidx_target.log(), src_oidx_input.clone().detach(), reduction='batchmean'))
                                  for src_oidx_input, src_oidx_target in zip(src_oidx_inputs, src_oidx_targets)]

            loss_h_consistency = torch.mean(torch.stack(loss_h_consistency))
            loss_o_consistency = torch.mean(torch.stack(loss_o_consistency))

            losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o,
                      'loss_h_consistency': loss_h_consistency, 'loss_o_consistency': loss_o_consistency}
        else:
            losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}

        return losses

    # >>> HOI Losses 2 : pair actions
    def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes, use_consis):
        assert 'pred_actions' in outputs
        src_actions = outputs['pred_actions'].flatten(end_dim=1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        # Construct Target --------------------------------------------------------------------------------------------------------------
        target_classes_o = torch.cat([t["pair_actions"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_classes = torch.full(src_actions.shape, 0, dtype=torch.float32, device=src_actions.device)
        target_classes[..., -1] = 1  # the last index for no-interaction is '1' if a label exists

        pos_classes = torch.full(target_classes[idx].shape, 0, dtype=torch.float32, device=src_actions.device)  # else, the last index for no-interaction is '0'
        pos_classes[:, :-1] = target_classes_o.float()
        target_classes[idx] = pos_classes
        # --------------------------------------------------------------------------------------------------------------------------------

        # Focal BCE Loss -----------------------------------------------------------------------------------------------------------------
        logits = src_actions.sigmoid()
        loss_bce = F.binary_cross_entropy(logits[..., self.valid_ids], target_classes[..., self.valid_ids], reduction='none')
        p_t = logits[..., self.valid_ids] * target_classes[..., self.valid_ids] + (1 - logits[..., self.valid_ids]) * (1 - target_classes[..., self.valid_ids])
        loss_bce = ((1 - p_t) ** 2 * loss_bce)
        alpha_t = 0.25 * target_classes[..., self.valid_ids] + (1 - 0.25) * (1 - target_classes[..., self.valid_ids])
        loss_focal = alpha_t * loss_bce
        loss_act = loss_focal.sum() / max(target_classes[..., self.valid_ids[:-1]].sum(), 1)
        # --------------------------------------------------------------------------------------------------------------------------------

        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_action_inputs = [F.logsigmoid(outputs['pred_actions'][i][consistency_idx[0]]) for i, consistency_idx in enumerate(consistency_idxs)]
            src_action_targets = [F.logsigmoid(outputs['pred_actions'][i][consistency_idx[1]]) for i, consistency_idx in enumerate(consistency_idxs)]

            loss_action_consistency = [F.mse_loss(src_action_input, src_action_target)
                                       for src_action_input, src_action_target in zip(src_action_inputs, src_action_targets)]
            loss_action_consistency = torch.mean(torch.stack(loss_action_consistency))
            losses = {'loss_act': loss_act, 'loss_act_consistency': loss_action_consistency}
        else:
            losses = {'loss_act': loss_act}
        return losses

    # HOI Losses 3 : action targets
    def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions, use_consis, log=True):
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']
        nu, q, hd = outputs['pred_obj_logits'].shape
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        target_classes_o = torch.cat([t['pair_targets'][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        pad_tgt = -1  # src_logits.shape[2]-1
        target_classes = torch.full(src_logits.shape[:2], pad_tgt, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.tgt_empty_weight, ignore_index=-1)

        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_logits_inputs = [F.softmax(src_logits.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1)
                                 for i, consistency_idx in enumerate(consistency_idxs)]
            src_logits_targets = [F.softmax(src_logits.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1)
                                  for i, consistency_idx in enumerate(consistency_idxs)]
            loss_tgt_consistency = [0.5 * (F.kl_div(src_logit_input.log(), src_logit_target.clone().detach(), reduction='batchmean')
                                           + F.kl_div(src_logit_target.log(), src_logit_input.clone().detach(), reduction='batchmean'))
                                    for src_logit_input, src_logit_target in zip(src_logits_inputs, src_logits_targets)]
            loss_tgt_consistency = torch.mean(torch.stack(loss_tgt_consistency))
            losses = {'loss_tgt': loss_obj_ce, "loss_tgt_label_consistency": loss_tgt_consistency}
        else:
            losses = {'loss_tgt': loss_obj_ce}
        if log:
            ignore_idx = (target_classes_o != -1)
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx][ignore_idx, :-1], target_classes_o[ignore_idx])[0]
            # losses['obj_class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_consistency_src_permutation_idx(self, indices):
        # pair up, across decoding paths, the queries matched to the same GT interaction
        all_tgt = torch.cat([j for (_, j) in indices]).unique()
        path_idxs = [torch.cat([torch.tensor([i]) for i, (_, t) in enumerate(indices) if (t == tgt).any()]) for tgt in all_tgt]
        q_idxs = [torch.cat([s[t == tgt] for (s, t) in indices]) for tgt in all_tgt]
        path_idxs = torch.cat([torch.combinations(path_idx) for path_idx in path_idxs if len(path_idx) > 1])
        q_idxs = torch.cat([torch.combinations(q_idx) for q_idx in q_idxs if len(q_idx) > 1])

        return (path_idxs[:, 0], q_idxs[:, 0]), (path_idxs[:, 1], q_idxs[:, 1])

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    # *****************************************************************************
    # >>> DETR Losses
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    # >>> HOTR Losses
    def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes, use_consis, **kwargs):
        loss_map = {
            'pair_labels': self.loss_pair_labels,
            'pair_actions': self.loss_pair_actions
        }
        if self.dataset_file == 'hico-det': loss_map['pair_targets'] = self.loss_pair_targets
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, use_consis, **kwargs)
    # *****************************************************************************

    def forward(self, outputs, targets, log=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if (k != 'aux_outputs' and k != 'hoi_aux_outputs')}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        if self.HOI_losses is not None:
            input_targets = [copy.deepcopy(target) for target in targets]
            hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # HOI detection losses
        if self.HOI_losses is not None:
            for loss in self.HOI_losses:
                losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis))
                # if self.dataset_file == 'hico-det': losses['loss_oidx'] += losses['loss_tgt']

            if 'hoi_aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
                    input_targets = [copy.deepcopy(target) for target in targets]
                    hoi_indices, targets_for_aux = self.HOI_matcher(aux_outputs, input_targets, indices, log)
                    for loss in self.HOI_losses:
                        kwargs = {}
                        if loss == 'pair_targets': kwargs = {'log': False}  # Logging is enabled only for the last layer
                        l_dict = self.get_HOI_loss(loss, aux_outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis, **kwargs)
                        l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                        losses.update(l_dict)
                        # if self.dataset_file == 'hico-det': losses[f'loss_oidx_{i}'] += losses[f'loss_tgt_{i}']

        return losses
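The hand-rolled modulation in loss_pair_actions (alpha_t * (1 - p_t)**2 * BCE with alpha = 0.25) is exactly sigmoid focal loss with gamma = 2. A quick numerical check against torchvision's reference implementation (torchvision >= 0.8 assumed):

import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(4, 10)
targets = torch.randint(0, 2, (4, 10)).float()

# the same computation as loss_pair_actions, done on probabilities
p = logits.sigmoid()
bce = F.binary_cross_entropy(p, targets, reduction='none')
p_t = p * targets + (1 - p) * (1 - targets)
alpha_t = 0.25 * targets + 0.75 * (1 - targets)
manual = alpha_t * (1 - p_t) ** 2 * bce

reference = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction='none')
print(torch.allclose(manual, reference, atol=1e-6))  # True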
hotr/models/detr.py
ADDED
@@ -0,0 +1,187 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/detr.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
DETR & HOTR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn

from hotr.util.misc import (NestedTensor, nested_tensor_from_tensor_list)

from .backbone import build_backbone
from .detr_matcher import build_matcher
from .hotr_matcher import build_hoi_matcher
from .transformer import build_transformer, build_hoi_transformer
from .criterion import SetCriterion
from .post_process import PostProcess
from .feed_forward import MLP

from .hotr import HOTR
from .hotr_v1 import HOTR_V1

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]


def build(args):
    device = torch.device(args.device)

    backbone = build_backbone(args)

    transformer = build_transformer(args)

    model = DETR(
        backbone,
        transformer,
        num_classes=args.num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )

    matcher = build_matcher(args)
    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef

    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ['labels', 'boxes', 'cardinality'] if args.frozen_weights is None else []
    if args.HOIDet:
        hoi_matcher = build_hoi_matcher(args)
        hoi_losses = []
        hoi_losses.append('pair_labels')
        hoi_losses.append('pair_actions')
        if args.dataset_file == 'hico-det': hoi_losses.append('pair_targets')

        hoi_weight_dict = {}
        hoi_weight_dict['loss_hidx'] = args.hoi_idx_loss_coef
        hoi_weight_dict['loss_oidx'] = args.hoi_idx_loss_coef
        hoi_weight_dict['loss_h_consistency'] = args.hoi_idx_consistency_loss_coef
        hoi_weight_dict['loss_o_consistency'] = args.hoi_idx_consistency_loss_coef
        hoi_weight_dict['loss_act'] = args.hoi_act_loss_coef
        hoi_weight_dict['loss_act_consistency'] = args.hoi_act_consistency_loss_coef
        if args.dataset_file == 'hico-det':
            hoi_weight_dict['loss_tgt'] = args.hoi_tgt_loss_coef
            hoi_weight_dict['loss_tgt_consistency'] = args.hoi_tgt_consistency_loss_coef
        if args.hoi_aux_loss:
            hoi_aux_weight_dict = {}
            for i in range(args.hoi_dec_layers):
                hoi_aux_weight_dict.update({k + f'_{i}': v for k, v in hoi_weight_dict.items()})
            hoi_weight_dict.update(hoi_aux_weight_dict)

        criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=hoi_weight_dict,
                                 eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions,
                                 HOI_losses=hoi_losses, HOI_matcher=hoi_matcher, args=args)

        interaction_transformer = build_hoi_transformer(args)  # if (args.share_enc and args.pretrained_dec) else None

        kwargs = {}
        if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids
        if args.sep_enc_forward:
            model = HOTR_V1(
                detr=model,
                num_hoi_queries=args.num_hoi_queries,
                num_actions=args.num_actions,
                interaction_transformer=interaction_transformer,
                augpath_name=args.augpath_name,
                share_dec_param=args.share_dec_param,
                stop_grad_stage=args.stop_grad_stage,
                freeze_detr=(args.frozen_weights is not None),
                share_enc=args.share_enc,
                pretrained_dec=args.pretrained_dec,
                temperature=args.temperature,
                hoi_aux_loss=args.hoi_aux_loss,
                **kwargs  # only return verb class for HICO-DET dataset
            )
        else:
            model = HOTR(
                detr=model,
                num_hoi_queries=args.num_hoi_queries,
                num_actions=args.num_actions,
                interaction_transformer=interaction_transformer,
                augpath_name=args.augpath_name,
                share_dec_param=args.share_dec_param,
                stop_grad_stage=args.stop_grad_stage,
                freeze_detr=(args.frozen_weights is not None),
                share_enc=args.share_enc,
                pretrained_dec=args.pretrained_dec,
                temperature=args.temperature,
                hoi_aux_loss=args.hoi_aux_loss,
                **kwargs  # only return verb class for HICO-DET dataset
            )
        postprocessors = {'hoi': PostProcess(args.HOIDet)}
    else:
        criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
                                 eos_coef=args.eos_coef, losses=losses)
        postprocessors = {'bbox': PostProcess(args.HOIDet)}
    criterion.to(device)

    return model, criterion, postprocessors
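The aux-loss "hack" above just replicates every loss weight once per intermediate decoder layer, matching the _{i}-suffixed keys the criterion emits for 'aux_outputs' / 'hoi_aux_outputs'. A standalone illustration of that expansion, with made-up placeholder coefficients:

weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # placeholder values
dec_layers = 6
aux_weight_dict = {}
for i in range(dec_layers - 1):
    aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
print(sorted(weight_dict))  # loss_bbox, loss_bbox_0, ..., loss_giou_4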
hotr/models/detr_matcher.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)               # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher(args):
    return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
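A minimal sketch of calling the matcher on dummy tensors. The class count and cost weights here are illustrative only; boxes are normalized (cx, cy, w, h), as the matcher expects:

import torch
from hotr.models.detr_matcher import HungarianMatcher

matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
outputs = {
    'pred_logits': torch.randn(2, 100, 92),  # 2 images, 100 queries, 91 classes + no-object
    'pred_boxes': torch.rand(2, 100, 4),     # normalized (cx, cy, w, h)
}
targets = [
    {'labels': torch.tensor([3, 17]), 'boxes': torch.rand(2, 4)},
    {'labels': torch.tensor([5]), 'boxes': torch.rand(1, 4)},
]
indices = matcher(outputs, targets)
# one (pred_idx, tgt_idx) pair of index tensors per image
print([(i.tolist(), j.tolist()) for i, j in indices])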
hotr/models/feed_forward.py
ADDED
@@ -0,0 +1,16 @@
import torch.nn.functional as F
from torch import nn

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
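A quick shape check: HOTR builds its 3-layer pointer and box heads from this MLP, applied over (batch, queries, hidden) tensors:

import torch
from hotr.models.feed_forward import MLP

bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
queries = torch.randn(2, 100, 256)   # (batch, num_queries, hidden_dim)
print(bbox_head(queries).shape)      # torch.Size([2, 100, 4])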
hotr/models/hotr.py
ADDED
@@ -0,0 +1,241 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import time
import datetime

from hotr.util.misc import NestedTensor, nested_tensor_from_tensor_list
from .feed_forward import MLP

class HOTR(nn.Module):
    def __init__(self, detr,
                 num_hoi_queries,
                 num_actions,
                 interaction_transformer,
                 augpath_name,
                 share_dec_param,
                 stop_grad_stage,
                 freeze_detr,
                 share_enc,
                 pretrained_dec,
                 temperature,
                 hoi_aux_loss,
                 return_obj_class=None):
        super().__init__()

        # * Instance Transformer ---------------
        self.detr = detr
        if freeze_detr:
            # if this flag is given, freeze the object detection related parameters of DETR
            for p in self.parameters():
                p.requires_grad_(False)
        hidden_dim = detr.transformer.d_model
        # --------------------------------------

        # * Interaction Transformer -----------------------------------------
        self.num_queries = num_hoi_queries
        self.query_embed = nn.Embedding(self.num_queries, hidden_dim)
        self.H_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.O_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.action_embed = nn.Linear(hidden_dim, num_actions + 1)
        # --------------------------------------------------------------------

        # * HICO-DET FFN heads -----------------------------------------------
        self.return_obj_class = (return_obj_class is not None)
        if return_obj_class: self._valid_obj_ids = return_obj_class + [return_obj_class[-1] + 1]
        # --------------------------------------------------------------------

        # * Transformer Options ----------------------------------------------
        self.interaction_transformer = interaction_transformer

        if share_enc: # share encoder
            self.interaction_transformer.encoder = detr.transformer.encoder

        if pretrained_dec: # free variables for interaction decoder
            self.interaction_transformer.decoder = copy.deepcopy(detr.transformer.decoder)
            for p in self.interaction_transformer.decoder.parameters():
                p.requires_grad_(True)
        # --------------------------------------------------------------------

        # * Augmented paths ----------------------------------------------------
        self.aug_paths = augpath_name

        if 'p2' in augpath_name:
            if not share_dec_param:
                self.xtoHO_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
                self.HOtoI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
            else:
                self.xtoHO_interaction_decoder = self.interaction_transformer.decoder
                self.HOtoI_interaction_decoder = self.interaction_transformer.decoder

            self.query_embed_HOtoI = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_HOtoI2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_HOtoI = nn.Linear(hidden_dim, num_actions + 1)

        if 'p3' in augpath_name:
            if not share_dec_param:
                self.xtoHI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
                self.HItoO_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
            else:
                self.xtoHI_interaction_decoder = self.interaction_transformer.decoder
                self.HItoO_interaction_decoder = self.interaction_transformer.decoder

            self.query_embed_HItoO = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_HItoO2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_HItoO = nn.Linear(hidden_dim, num_actions + 1)

        if 'p4' in augpath_name:
            if not share_dec_param:
                self.xtoOI_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
                self.OItoH_interaction_decoder = copy.deepcopy(self.interaction_transformer.decoder)
            else:
                self.xtoOI_interaction_decoder = self.interaction_transformer.decoder
                self.OItoH_interaction_decoder = self.interaction_transformer.decoder

            self.query_embed_OItoH = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_OItoH2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_OItoH = nn.Linear(hidden_dim, num_actions + 1)

        self.stop_grad_stage = stop_grad_stage

        # * Loss Options -------------------
        self.tau = temperature
        self.hoi_aux_loss = hoi_aux_loss
        # ----------------------------------

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        # >>>>>>>>>>>> BACKBONE LAYERS <<<<<<<<<<<<<<<
        features, pos = self.detr.backbone(samples)
        bs = features[-1].tensors.shape[0]
        src, mask = features[-1].decompose()
        assert mask is not None
        # ----------------------------------------------

        # >>>>>>>>>>>> OBJECT DETECTION LAYERS <<<<<<<<<<
        start_time = time.time()
        hs, memory = self.detr.transformer(self.detr.input_proj(src), mask, self.detr.query_embed.weight, pos[-1])
        inst_repr = F.normalize(hs[-1], p=2, dim=2) # instance representations

        # Prediction Heads for Object Detection
        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        object_detection_time = time.time() - start_time
        # -----------------------------------------------

        # >>>>>>>>>>>> HOI DETECTION LAYERS <<<<<<<<<<<<<<<
        start_time = time.time()
        assert hasattr(self, 'interaction_transformer'), "Missing Interaction Transformer."
        H_Pointer_reprs_bag, O_Pointer_reprs_bag, outputs_action = [], [], []
        # main path P1
        interaction_hs = self.interaction_transformer(self.detr.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # interaction representations
        H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed(interaction_hs), p=2, dim=-1))
        O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed(interaction_hs), p=2, dim=-1))
        outputs_action.append(self.action_embed(interaction_hs))

        if len(self.aug_paths) != 0:
            pos_aug = pos[-1].flatten(2).permute(2, 0, 1)
            mask_aug = mask.flatten(1)

            # P2 (x->HO->I)
            if 'p2' in self.aug_paths:
                tgt_2 = torch.zeros_like(self.query_embed_HOtoI.weight.unsqueeze(1).repeat(1, bs, 1))
                hs_HOtoI = self.xtoHO_interaction_decoder(tgt_2, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HOtoI.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                tgt_HOtoI = hs_HOtoI.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_HOtoI.clone().detach().transpose(1, 2)[-1]
                hs2_HOtoI = self.HOtoI_interaction_decoder(tgt_HOtoI, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HOtoI2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
                outputs_action.append(self.action_embed_HOtoI(hs2_HOtoI))
            # P3 (x->HI->O)
            if 'p3' in self.aug_paths:
                tgt_3 = torch.zeros_like(self.query_embed_HItoO.weight.unsqueeze(1).repeat(1, bs, 1))
                hs_HItoO = self.xtoHI_interaction_decoder(tgt_3, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HItoO.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                tgt_HItoO = hs_HItoO.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_HItoO.clone().detach().transpose(1, 2)[-1]
                hs2_HItoO = self.HItoO_interaction_decoder(tgt_HItoO, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_HItoO2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HItoO(hs_HItoO), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HItoO(hs2_HItoO), p=2, dim=-1))
                outputs_action.append(self.action_embed_HItoO(hs_HItoO))
            # P4 (x->OI->H)
            if 'p4' in self.aug_paths:
                tgt_4 = torch.zeros_like(self.query_embed_OItoH.weight.unsqueeze(1).repeat(1, bs, 1))
                hs_OItoH = self.xtoOI_interaction_decoder(tgt_4, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_OItoH.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                tgt_OItoH = hs_OItoH.transpose(1, 2)[-1] if not self.stop_grad_stage else hs_OItoH.clone().detach().transpose(1, 2)[-1]
                hs2_OItoH = self.OItoH_interaction_decoder(tgt_OItoH, memory, memory_key_padding_mask=mask_aug, pos=pos_aug, query_pos=self.query_embed_OItoH2.weight.unsqueeze(1).repeat(1, bs, 1)).transpose(1, 2)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_OItoH(hs2_OItoH), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_OItoH(hs_OItoH), p=2, dim=-1))
                outputs_action.append(self.action_embed_OItoH(hs_OItoH))

        inst_repr_all = inst_repr.transpose(1, 2).repeat(1 + len(self.aug_paths), 1, 1)

        H_Pointer_reprs_bag = torch.cat(H_Pointer_reprs_bag, 1)
        O_Pointer_reprs_bag = torch.cat(O_Pointer_reprs_bag, 1)

        outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr_all)) / self.tau for H_Pointer_repr in H_Pointer_reprs_bag] # (dec_layer, (1+len(aug))*bs, dec_q, hidden_dim)
        outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr_all)) / self.tau for O_Pointer_repr in O_Pointer_reprs_bag]

        outputs_action = torch.stack(outputs_action, dim=2) # (dec_layer, bs, 1+#aug, dec_q, #action)

        # --------------------------------------------------
        hoi_detection_time = time.time() - start_time
        hoi_recognition_time = max(hoi_detection_time - object_detection_time, 0)
        # -------------------------------------------------------------------

        # [Target Classification]
        if self.return_obj_class:
            detr_logits = outputs_class[-1, ..., self._valid_obj_ids]
            o_indices = [output_oidx.max(-1)[-1].view(1 + len(self.aug_paths), bs, self.num_queries).transpose(0, 1) for output_oidx in outputs_oidx]
            obj_logit_stack = [torch.stack([detr_logits[batch_, o_idx, :] for batch_, o_idc in enumerate(o_indice) for o_idx in o_idc], 0) for o_indice in o_indices]
            outputs_obj_class = obj_logit_stack

        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1],
            "pred_hidx": outputs_hidx[-1],
            "pred_oidx": outputs_oidx[-1],
            "pred_actions": outputs_action[-1],
            "hoi_recognition_time": hoi_recognition_time,
        }

        if self.return_obj_class: out["pred_obj_logits"] = outputs_obj_class[-1]

        if self.hoi_aux_loss: # auxiliary loss
            out['hoi_aux_outputs'] = \
                self._set_aux_loss_with_tgt(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_obj_class) \
                if self.return_obj_class else \
                self._set_aux_loss(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action)

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action):
        return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e}
                for a, b, c, d, e in zip(
                    outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
                    outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
                    outputs_hidx[:-1],
                    outputs_oidx[:-1],
                    outputs_action[:-1])]

    @torch.jit.unused
    def _set_aux_loss_with_tgt(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_tgt):
        return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e, 'pred_obj_logits': f}
                for a, b, c, d, e, f in zip(
                    outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
                    outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
                    outputs_hidx[:-1],
                    outputs_oidx[:-1],
                    outputs_action[:-1],
                    outputs_tgt[:-1])]
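For reference, a minimal standalone sketch of the pointer mechanism used in HOTR.forward above: each HOI query emits H/O pointer vectors, and the index logits are the scaled dot products against the L2-normalized DETR instance representations, mirroring `torch.bmm(H_Pointer_repr, inst_repr_all) / self.tau`. All sizes below are toy placeholders, not values from this config.

import torch
import torch.nn.functional as F

bs, num_inst_queries, num_hoi_queries, hidden_dim, tau = 2, 100, 16, 256, 0.05

# L2-normalized instance representations (stand-in for DETR's last decoder layer)
inst_repr = F.normalize(torch.randn(bs, num_inst_queries, hidden_dim), p=2, dim=-1)
# L2-normalized human-pointer vectors (stand-in for H_Pointer_embed output)
h_pointer = F.normalize(torch.randn(bs, num_hoi_queries, hidden_dim), p=2, dim=-1)

# (bs, num_hoi_queries, num_inst_queries): similarity of each HOI query to each instance
h_idx_logits = torch.bmm(h_pointer, inst_repr.transpose(1, 2)) / tau
h_indices = h_idx_logits.softmax(-1).argmax(-1)  # which instance each HOI query points to
print(h_indices.shape)  # torch.Size([2, 16])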
hotr/models/hotr_matcher.py
ADDED
@@ -0,0 +1,216 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr_matcher.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

import hotr.util.misc as utils
import wandb

class HungarianPairMatcher(nn.Module):
    def __init__(self, args):
        """Creates the matcher
        Params:
            cost_action: This is the relative weight of the multi-label action classification error in the matching cost
            cost_hbox: This is the relative weight of the classification error for human idx in the matching cost
            cost_obox: This is the relative weight of the classification error for object idx in the matching cost
        """
        super().__init__()
        self.cost_action = args.set_cost_act
        self.cost_hbox = self.cost_obox = args.set_cost_idx
        self.cost_target = args.set_cost_tgt
        self.log_printer = args.wandb
        self.is_vcoco = (args.dataset_file == 'vcoco')
        self.is_hico = (args.dataset_file == 'hico-det')
        if self.is_vcoco:
            self.valid_ids = args.valid_ids
            self.invalid_ids = args.invalid_ids
        assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs can't be 0"

    def reduce_redundant_gt_box(self, tgt_bbox, indices):
        """Filters redundant Ground-Truth Bounding Boxes
        Due to random crop augmentation, there are cases with multiple
        redundant labels for the exact same bounding box and object class.
        This function deals with the redundant labels for smoother HOTR training.
        """
        tgt_bbox_unique, map_idx, idx_cnt = torch.unique(tgt_bbox, dim=0, return_inverse=True, return_counts=True)

        k_idx, bbox_idx = indices

        triggered = False
        if (len(tgt_bbox) != len(tgt_bbox_unique)):
            map_dict = {k: v for k, v in enumerate(map_idx)}
            map_bbox2kidx = {int(bbox_id): k_id for bbox_id, k_id in zip(bbox_idx, k_idx)}

            bbox_lst, k_lst = [], []
            for bbox_id in bbox_idx:
                if map_dict[int(bbox_id)] not in bbox_lst:
                    bbox_lst.append(map_dict[int(bbox_id)])
                    k_lst.append(map_bbox2kidx[int(bbox_id)])
            bbox_idx = torch.tensor(bbox_lst)
            k_idx = torch.tensor(k_lst)
            tgt_bbox_res = tgt_bbox_unique
        else:
            tgt_bbox_res = tgt_bbox

        bbox_idx = bbox_idx.to(tgt_bbox.device)

        return tgt_bbox_res, k_idx, bbox_idx

    @torch.no_grad()
    def forward(self, outputs, targets, indices, log=False):
        assert "pred_actions" in outputs, "There is no action output for pair matching"
        num_obj_queries = outputs["pred_boxes"].shape[1]
        bs, num_path, num_queries = outputs["pred_actions"].shape[:3]
        detr_query_num = outputs["pred_logits"].shape[1] \
            if (outputs["pred_oidx"].shape[-1] == (outputs["pred_logits"].shape[1] + 1)) else -1

        return_list = []
        if self.log_printer and log:
            log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
            if self.is_hico: log_dict['tgt_cost'] = []

        for batch_idx in range(bs):
            tgt_bbox = targets[batch_idx]["boxes"] # (num_boxes, 4)
            tgt_cls = targets[batch_idx]["labels"] # (num_boxes)

            if self.is_vcoco:
                targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0
                keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0)
                targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx]
                targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx]
                targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx]

                tgt_pbox = targets[batch_idx]["pair_boxes"] # (num_pair_boxes, 8)
                tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 29)
                tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes)

                tgt_hbox = tgt_pbox[:, :4] # (num_pair_boxes, 4)
                tgt_obox = tgt_pbox[:, 4:] # (num_pair_boxes, 4)
            elif self.is_hico:
                tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 117)
                tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes)

                tgt_hbox = targets[batch_idx]["sub_boxes"] # (num_pair_boxes, 4)
                tgt_obox = targets[batch_idx]["obj_boxes"] # (num_pair_boxes, 4)

            # find which gt boxes match the h, o boxes in the pair
            if self.is_vcoco:
                hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            elif self.is_hico:
                hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
            obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1 # turn the class of occluded objects to -1

            bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
            bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
            bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)

            cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1)
            cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1)

            # find which gt boxes matches which prediction in K
            h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False) # (num_hbox, num_boxes)
            o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False) # (num_obox, num_boxes)

            tgt_hids, tgt_oids = [], []

            # obtain ground truth indices for h
            assert len(h_match_indices) == len(o_match_indices), \
                "number of human and object box matches must be equal"

            for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices):
                hbox_idx, H_bbox_idx = h_match_idx
                obox_idx, O_bbox_idx = o_match_idx
                if O_bbox_idx == (len(bbox_with_cls) - 1): # if the object class is -1
                    O_bbox_idx = H_bbox_idx # happens in V-COCO, the target object may not appear

                GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                query_idx_for_H = k_idx[GT_idx_for_H]
                tgt_hids.append(query_idx_for_H)

                GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                query_idx_for_O = k_idx[GT_idx_for_O]
                tgt_oids.append(query_idx_for_O)

            # check if empty
            if len(tgt_hids) == 0: tgt_hids.append(torch.as_tensor([-1])) # we later ignore the label -1
            if len(tgt_oids) == 0: tgt_oids.append(torch.as_tensor([-1])) # we later ignore the label -1

            tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0)
            flag = False
            if tgt_act.shape[0] == 0:
                tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device)
                targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device)
                if self.is_hico:
                    pad_tgt = -1 # outputs["pred_obj_logits"].shape[-1]-1
                    tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"])
                    targets[batch_idx]["pair_targets"] = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device)
                tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)

            # Concat target label
            tgt_hids = torch.cat(tgt_hids).repeat(num_path)
            tgt_oids = torch.cat(tgt_oids).repeat(num_path)

            outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
            outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)

            outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
            out_hprob = outputs_hidx[batch_idx].softmax(-1)
            out_oprob = outputs_oidx[batch_idx].softmax(-1)
            out_act = outputs_action[batch_idx].clone()
            if self.is_vcoco: out_act[..., self.invalid_ids] = 0
            if self.is_hico:
                outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path, num_queries, -1).view(bs, num_path * num_queries, -1)
                out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
                out_tgt[..., -1] = 0 # don't get cost for no-object

            tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1).repeat(num_path, 1)

            cost_hclass = -out_hprob[:, tgt_hids] # [batch_size * num_queries, detr.num_queries+1]
            cost_oclass = -out_oprob[:, tgt_oids] # [batch_size * num_queries, detr.num_queries+1]

            cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
            cost_neg_act = (torch.matmul(out_act, (~tgt_act.bool()).type(torch.int64).t().float())) / (~tgt_act.bool()).type(torch.int64).sum(dim=-1).unsqueeze(0)
            cost_action = cost_pos_act + cost_neg_act

            h_cost = self.cost_hbox * cost_hclass
            o_cost = self.cost_obox * cost_oclass

            act_cost = self.cost_action * cost_action

            C = h_cost + o_cost + act_cost

            if self.is_hico:
                cost_target = -out_tgt[:, tgt_tgt.repeat(num_path)]
                tgt_cost = self.cost_target * cost_target
                C += tgt_cost
            C = C.view(num_path, num_queries, -1).cpu()

            sizes = [len(tgt_hids) // num_path] * num_path
            hoi_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return_list.append([(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in hoi_indices])

            targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device)
            targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device)
            log_act_cost = torch.zeros([1]).to(tgt_act.device) if tgt_act.shape[0] == 0 else act_cost.min(dim=0)[0].mean()

            if self.log_printer and log:
                log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
                if self.is_hico: log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())
        if self.log_printer and log:
            log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean()
            log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean()
            log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean()
            if self.is_hico: log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean()
            if utils.get_rank() == 0: wandb.log(log_dict)
        return return_list, targets

def build_hoi_matcher(args):
    return HungarianPairMatcher(args)
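A toy illustration of the Hungarian step used above: build a (num_queries x num_gt) cost matrix from the individual cost terms, then let SciPy pick the minimum-cost one-to-one assignment, as in `C = h_cost + o_cost + act_cost` followed by `linear_sum_assignment`. All tensors here are synthetic placeholders.

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_gt = 4, 2
h_cost = torch.rand(num_queries, num_gt)    # human-pointer classification cost
o_cost = torch.rand(num_queries, num_gt)    # object-pointer classification cost
act_cost = torch.rand(num_queries, num_gt)  # multi-label action cost

C = (h_cost + o_cost + act_cost).cpu()
row_ind, col_ind = linear_sum_assignment(C.numpy())  # query indices, GT-pair indices
print(list(zip(row_ind, col_ind)))  # e.g. [(0, 1), (3, 0)]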
hotr/models/position_encoding.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from hotr.util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
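A quick shape check for PositionEmbeddingSine: with hidden_dim=256 the builder passes N_steps=128 per axis, so the output has 2*128=256 channels. This sketch assumes the repo is on PYTHONPATH so the imports below resolve; the tensor sizes are arbitrary.

import torch
from hotr.models.position_encoding import PositionEmbeddingSine
from hotr.util.misc import NestedTensor

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
x = torch.randn(2, 256, 25, 38)                   # (batch, channels, H, W) feature map
mask = torch.zeros(2, 25, 38, dtype=torch.bool)   # no padded pixels in this toy case
pos = pe(NestedTensor(x, mask))
print(pos.shape)  # torch.Size([2, 256, 25, 38]): y-encoding and x-encoding concatenated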
hotr/models/post_process.py
ADDED
@@ -0,0 +1,162 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/post_process.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import time
import copy
import torch
import torch.nn.functional as F
from torch import nn
from hotr.util import box_ops

class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    def __init__(self, HOIDet):
        super().__init__()
        self.HOIDet = HOIDet

    @torch.no_grad()
    def forward(self, outputs, target_sizes, threshold=0, dataset='coco', args=None):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        num_path = 1 + len(args.augpath_name)
        path_id = args.path_id
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        # Prediction Branch for HOI detection
        if self.HOIDet:
            if dataset == 'vcoco':
                """ Compute HOI triplet prediction score for V-COCO.
                Our scoring function follows the implementation details of UnionDet.
                """
                out_time = outputs['hoi_recognition_time']
                bss, q, hd = outputs['pred_hidx'].shape
                start_time = time.time()
                pair_actions = torch.sigmoid(outputs['pred_actions'][:, path_id, ...])
                h_prob = F.softmax(outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
                h_idx_score, h_indices = h_prob.max(-1)

                o_prob = F.softmax(outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time

                results = []
                # iterate for batch size
                for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
                    h_inds = (l == 1) & (s > threshold)
                    o_inds = (s > threshold)

                    h_box, h_cat = b[h_inds], s[h_inds]
                    o_box, o_cat = b[o_inds], s[o_inds]

                    # for scenario 1 in v-coco dataset
                    o_inds = torch.cat((o_inds, torch.ones(1).type(torch.bool).to(o_inds.device)))
                    o_box = torch.cat((o_box, torch.Tensor([0, 0, 0, 0]).unsqueeze(0).to(o_box.device)))

                    result_dict = {
                        'h_box': h_box, 'h_cat': h_cat,
                        'o_box': o_box, 'o_cat': o_cat,
                        'scores': s, 'labels': l, 'boxes': b
                    }

                    h_inds_lst = h_inds.nonzero(as_tuple=False).squeeze(-1)
                    o_inds_lst = o_inds.nonzero(as_tuple=False).squeeze(-1)

                    K = boxes.shape[1]
                    n_act = pair_actions[batch_idx][:, :-1].shape[-1]
                    score = torch.zeros((n_act, K, K + 1)).to(pair_actions[batch_idx].device)
                    sorted_score = torch.zeros((n_act, K, K + 1)).to(pair_actions[batch_idx].device)
                    id_score = torch.zeros((K, K + 1)).to(pair_actions[batch_idx].device)

                    # Score function
                    for hs, h_idx, os, o_idx, pair_action in zip(h_idx_score[batch_idx], h_indices[batch_idx], o_idx_score[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]):
                        matching_score = (1 - pair_action[-1]) # no-interaction score
                        if h_idx == o_idx: o_idx = -1
                        if matching_score > id_score[h_idx, o_idx]:
                            id_score[h_idx, o_idx] = matching_score
                            sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1]
                        score[:, h_idx, o_idx] += matching_score * pair_action[:-1]

                    score += sorted_score
                    score = score[:, h_inds, :]
                    score = score[:, :, o_inds]

                    result_dict.update({
                        'pair_score': score,
                        'hoi_recognition_time': hoi_recognition_time,
                    })

                    results.append(result_dict)

            elif dataset == 'hico-det':
                """ Compute HOI triplet prediction score for HICO-DET.
                For HICO-DET, we follow the same scoring function but do not accumulate the results.
                """
                bss, q, hd = outputs['pred_hidx'].shape
                out_time = outputs['hoi_recognition_time']
                a, b, c = outputs['pred_obj_logits'].shape
                start_time = time.time()
                out_obj_logits = outputs['pred_obj_logits'].view(-1, num_path, b, c)[:, path_id, ...]
                out_verb_logits = outputs['pred_actions'][:, path_id, ...]

                # actions
                matching_scores = (1 - out_verb_logits.sigmoid()[..., -1:])
                verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores

                # hbox, obox
                outputs_hrepr, outputs_orepr = outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id], outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id]
                obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)

                h_prob = F.softmax(outputs_hrepr, -1)
                h_idx_score, h_indices = h_prob.max(-1)

                # targets
                o_prob = F.softmax(outputs_orepr, -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time

                # hidx, oidx
                sub_boxes, obj_boxes = [], []
                for batch_id, (box, h_idx, o_idx) in enumerate(zip(boxes, h_indices, o_indices)):
                    sub_boxes.append(box[h_idx, :])
                    obj_boxes.append(box[o_idx, :])
                sub_boxes = torch.stack(sub_boxes, dim=0)
                obj_boxes = torch.stack(obj_boxes, dim=0)

                # accumulate results (iterate through interaction queries)
                results = []
                for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes):
                    sl = torch.full_like(ol, 0) # self.subject_category_id = 0 in HICO-DET
                    l = torch.cat((sl, ol))
                    b = torch.cat((sb, ob))
                    results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')})
                    vs = vs * os.unsqueeze(1)
                    ids = torch.arange(b.shape[0])
                    res_dict = {
                        'verb_scores': vs.to('cpu'),
                        'sub_ids': ids[:ids.shape[0] // 2],
                        'obj_ids': ids[ids.shape[0] // 2:],
                        'hoi_recognition_time': hoi_recognition_time
                    }
                    results[-1].update(res_dict)
        else:
            results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results
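A toy illustration of the HICO-DET scoring above: the last action channel acts as a "no interaction" logit, so each verb score is gated by (1 - sigmoid(no_interaction)) and then by the matched object's class score. The 117-verb width follows the pair_actions comment in the matcher; the values below are synthetic.

import torch

out_verb_logits = torch.randn(3, 118)  # 117 verbs + 1 no-interaction channel (synthetic)
obj_scores = torch.rand(3)             # class score of the object each query points to

matching_scores = 1 - out_verb_logits.sigmoid()[..., -1:]          # (3, 1) interactiveness gate
verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores
verb_scores = verb_scores * obj_scores.unsqueeze(1)                 # final (3, 117) triplet scores
print(verb_scores.shape)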
hotr/models/transformer.py
ADDED
@@ -0,0 +1,320 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/transformer.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
DETR & HOTR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed, query_obj=None, return_decoder_input=False):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)

        if query_embed.dim() == 2:
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        if query_obj is None:
            hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
        else:
            hs = self.decoder(query_obj, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)

        return hs.transpose(1, 2), memory

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def build_hoi_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.hoi_nheads,
        dim_feedforward=args.hoi_dim_feedforward,
        num_encoder_layers=args.hoi_enc_layers,
        num_decoder_layers=args.hoi_dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
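A smoke test for the Transformer above, on a toy feature map. Shapes follow forward(): src is (B, C, H, W), the mask marks padded pixels, and the returned hs is (num_layers, B, num_queries, d_model) because return_intermediate_dec stacks every decoder layer's output. All sizes are arbitrary; the import assumes the repo is on PYTHONPATH.

import torch
from hotr.models.transformer import Transformer

model = Transformer(d_model=64, nhead=4, num_encoder_layers=2,
                    num_decoder_layers=2, dim_feedforward=128,
                    return_intermediate_dec=True)
src = torch.randn(2, 64, 10, 12)                 # projected backbone features
mask = torch.zeros(2, 10, 12, dtype=torch.bool)  # no padding in this toy batch
query_embed = torch.randn(20, 64)                # 20 learned query embeddings
pos_embed = torch.randn(2, 64, 10, 12)           # positional encoding for the feature map

hs, memory = model(src, mask, query_embed, pos_embed)
print(hs.shape, memory.shape)  # torch.Size([2, 2, 20, 64]) torch.Size([120, 2, 64])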
hotr/util/__init__.py
ADDED
File without changes
hotr/util/box_ops.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    The boxes should be in [x0, y0, x1, y1] format
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks
    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
    Returns a [N, 4] tensor, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)


def rescale_bboxes(out_bbox, size):
    img_h, img_w = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.get_device())
    return b


def rescale_pairs(out_pairs, size):
    img_h, img_w = size
    h_bbox = out_pairs[:, :4]
    o_bbox = out_pairs[:, 4:]

    h = box_cxcywh_to_xyxy(h_bbox)
    h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.get_device())

    obj_mask = (o_bbox[:, 0] != -1)
    if obj_mask.sum() != 0:
        o = box_cxcywh_to_xyxy(o_bbox)
        o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.get_device())
        o_bbox[obj_mask] = o[obj_mask]
    o = o_bbox
    p = torch.cat([h, o], dim=-1)

    return p
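
A minimal sketch of how these helpers compose (illustration only, not part of this commit; toy CPU tensors, avoiding the CUDA-only `rescale_*` helpers):

    import torch
    from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

    # two boxes in (cx, cy, w, h), normalized coordinates
    boxes = torch.tensor([[0.5, 0.5, 0.4, 0.4],
                          [0.3, 0.3, 0.2, 0.2]])
    xyxy = box_cxcywh_to_xyxy(boxes)        # -> (x0, y0, x1, y1)
    giou = generalized_box_iou(xyxy, xyxy)  # [2, 2] pairwise matrix
    print(giou.diag())                      # identical boxes on the diagonal -> 1.0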
hotr/util/logger.py
ADDED
@@ -0,0 +1,145 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/util/logger.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import time
import datetime
import sys
from time import sleep
from collections import defaultdict

from hotr.util.misc import SmoothedValue

def print_params(model):
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('\n[Logger] Number of params: ', n_parameters)
    return n_parameters

def print_args(args):
    print('\n[Logger] DETR Arguments:')
    for k, v in vars(args).items():
        if k in [
            'lr', 'lr_backbone', 'lr_drop',
            'frozen_weights',
            'backbone', 'dilation',
            'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
            'dataset_file']:
            print(f'\t{k}: {v}')

    if args.HOIDet:
        print('\n[Logger] DETR_HOI Arguments:')
        for k, v in vars(args).items():
            if k in [
                'freeze_enc',
                'query_flag',
                'hoi_nheads',
                'hoi_dim_feedforward',
                'hoi_dec_layers',
                'hoi_idx_loss_coef',
                'hoi_act_loss_coef',
                'hoi_eos_coef',
                'object_threshold']:
                print(f'\t{k}: {v}')

class MetricLogger(object):
    def __init__(self, mode="test", delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.mode = mode

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)

            if (i % print_freq == 0 and i != 0) or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i+1, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB),
                        flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
                else:
                    print(log_msg.format(
                        i+1, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)),
                        flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
            else:
                log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]'])
                if torch.cuda.is_available(): print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")
                else: print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")

            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if self.mode=='test': print("")
        print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
            self.mode, total_time_str, total_time / len(iterable)))
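
A minimal `MetricLogger` sketch (illustration only, not part of this commit):

    import torch
    from hotr.util.logger import MetricLogger

    meter = MetricLogger(mode="train", delimiter="  ")
    for step in meter.log_every(range(100), print_freq=10, header="Epoch[0]"):
        loss = torch.rand(1).item()   # stand-in for a real training loss
        meter.update(loss=loss)
    # each tracked value is a SmoothedValue, e.g. meter.loss.global_avg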
hotr/util/misc.py
ADDED
@@ -0,0 +1,401 @@
# ------------------------------------------------------------------------
# HOTR official code : hotr/util/misc.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import argparse  # needed by arg_as_list below
import subprocess
import time      # needed by _maybe_gethostbyname below
from collections import deque
import pickle
import socket
from typing import Optional, List
import ast
import torch
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
# compare (major, minor) tuples; slicing the version string breaks for e.g. "0.10"
_TORCHVISION_VERSION = tuple(int(v) for v in torchvision.__version__.split('.')[:2])
if _TORCHVISION_VERSION < (0, 7):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size

os.environ['MASTER_PORT'] = '8993'
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor  # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def _check_if_valid_ip(ip):
    try:
        socket.inet_aton(ip)
        # legal
    except socket.error:
        # Not legal
        return False
    return True

def arg_as_list(s):
    v = ast.literal_eval(s)
    if type(v) is not list:
        raise argparse.ArgumentTypeError("List should be given.")
    return v

def _maybe_gethostbyname(addr):
    """to be compatible with Braincloud on which one can access the nodes by their task names.
    Each node has to wait until all the tasks in the group are up on the cloud."""
    if _check_if_valid_ip(addr):
        # If an IP address is given, do nothing
        return addr

    # Otherwise, find the IP address by hostname
    done = False
    retry = 0
    print(f"Get URL by the given hostname '{addr}' in Braincloud..")
    while not done:
        except_count = retry  # noqa: F841 -- keep loop state readable
        try:
            addr = socket.gethostbyname(addr)
            done = True
        except OSError:  # narrowed from a bare except; DNS failures raise socket.gaierror (an OSError)
            retry += 1
            print(f"Retrying count: {retry}")
            time.sleep(3)
    print(f"Found the host by IP address: {addr}")
    return addr


def init_distributed_mode(args):

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        os.environ["MASTER_ADDR"] = _maybe_gethostbyname(os.environ["MASTER_ADDR"])
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if _TORCHVISION_VERSION < (0, 7):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
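
A minimal sketch of the padding/masking path above (illustration only, not part of this commit; CPU tensors):

    import torch
    from hotr.util.misc import nested_tensor_from_tensor_list

    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]  # mixed sizes
    nt = nested_tensor_from_tensor_list(imgs)
    tensors, mask = nt.decompose()
    print(tensors.shape)  # torch.Size([2, 3, 512, 640]) -- zero-padded to the max
    print(mask.shape)     # torch.Size([2, 512, 640])    -- True marks padding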
hotr/util/ramp.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import numpy as np

def sigmoid_rampup(current, rampup_length, max_coef=1.):
    """Exponential rampup from https://arxiv.org/abs/1610.02242
    Modified version from https://github.com/vikasverma1077/GraphMix/blob/master/semisupervised/codes/ramps.py"""
    if rampup_length == 0:
        return max_coef
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase)) * max_coef

def cosine_rampdown(current, rampdown_length, max_coef=1.):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) * max_coef
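
The two schedules at their endpoints (illustration only, not part of this commit):

    from hotr.util.ramp import sigmoid_rampup, cosine_rampdown

    print(sigmoid_rampup(0, 10))    # exp(-5) ~= 0.0067: start of the ramp-up
    print(sigmoid_rampup(10, 10))   # 1.0: max_coef reached at rampup_length
    print(cosine_rampdown(0, 10))   # 1.0: start of the ramp-down
    print(cosine_rampdown(10, 10))  # 0.0: fully decayed at rampdown_length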
imgs/mainfig.png
ADDED

main.py
ADDED
@@ -0,0 +1,240 @@
# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse
import datetime
import json
import random
import time
import multiprocessing
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import hotr.data.datasets as datasets
import hotr.util.misc as utils
from hotr.engine.arg_parser import get_args_parser
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
from hotr.engine.trainer import train_one_epoch
from hotr.engine import hoi_evaluator, hoi_accumulator
from hotr.models import build_model
import wandb

from hotr.util.logger import print_params, print_args

def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename):
    # save_ckpt: function for saving checkpoints
    output_dir = Path(args.output_dir)
    if args.output_dir:
        checkpoint_path = output_dir / f'{filename}.pth'
        utils.save_on_master({
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'args': args,
        }, checkpoint_path)

def main(args):
    utils.init_distributed_mode(args)

    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Data Setup
    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args)
    assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits"
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc: args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics
        args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
        args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    # Model Setup
    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = print_params(model)

    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # Weight Setup
    if args.frozen_weights is not None:
        if args.frozen_weights.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.frozen_weights, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        # test only mode
        if args.HOIDet:
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                sc1, sc2 = hoi_accumulator(args, total_res, True, False)
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}')
                print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}')
                print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}')
            else: raise ValueError(f'dataset {args.dataset_file} is not supported.')
            return
        else:
            # plain COCO detection evaluation; `evaluate_coco` is expected to come
            # from hotr/engine/evaluator_coco.py. base_ds and output_dir were
            # undefined here; they are set up below, following the original DETR main.py.
            base_ds = get_coco_api_from_dataset(dataset_val)
            output_dir = Path(args.output_dir)
            test_stats, coco_evaluator = evaluate_coco(model, criterion, postprocessors,
                                                       data_loader_val, base_ds, device, args.output_dir)
            if args.output_dir:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
            return

    # stats
    scenario1, scenario2 = 0, 0
    best_mAP, best_rare, best_non_rare = 0, 0, 0

    if args.wandb and utils.get_rank() == 0:
        wandb.init(
            project=args.project_name,
            group=args.group_name,
            name=args.run_name,
            config=args
        )
        wandb.watch(model)

    # Training starts here!
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch, args.epochs,
            args.ramp_up_epoch, args.ramp_down_epoch, args.hoi_consistency_loss_coef,
            args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb)
        lr_scheduler.step()

        # Validation
        if args.validate:
            print('-'*100)
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb)
                    if sc1 > scenario1:
                        scenario1 = sc1
                        scenario2 = sc2
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})')
                    print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})')
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    if test_stats['mAP'] > best_mAP:
                        best_mAP = test_stats['mAP']
                        best_rare = test_stats['mAP rare']
                        best_non_rare = test_stats['mAP non-rare']
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})')
                    print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})')
                    print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})')
                    if args.wandb and utils.get_rank() == 0:
                        wandb.log({
                            'mAP': test_stats['mAP'],
                            'mAP rare': test_stats['mAP rare'],
                            'mAP non-rare': test_stats['mAP non-rare']
                        })
            print('-'*100)

        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint')
        if (epoch + 1) % args.lr_drop == 0:
            save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_'+str(epoch))
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    if args.dataset_file == 'vcoco':
        print(f'| Scenario #1 mAP : {scenario1:.2f}')
        print(f'| Scenario #2 mAP : {scenario2:.2f}')
    elif args.dataset_file == 'hico-det':
        print(f'| mAP (full)\t\t: {best_mAP:.2f}')
        print(f'| mAP (rare)\t\t: {best_rare:.2f}')
        print(f'| mAP (non-rare)\t: {best_non_rare:.2f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'End-to-End Human Object Interaction training and evaluation script',
        parents=[get_args_parser()]
    )
    args = parser.parse_args()
    if args.output_dir:
        args.output_dir += f"/{args.group_name}/{args.run_name}/"
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
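
`save_ckpt` above stores five keys per checkpoint; a sketch of inspecting one (illustration only, not part of this commit; the path below is hypothetical):

    import torch

    ckpt = torch.load("checkpoints/best.pth", map_location="cpu")  # hypothetical path
    print(sorted(ckpt.keys()))  # ['args', 'epoch', 'lr_scheduler', 'model', 'optimizer']
    print(ckpt["epoch"])        # epoch at which this checkpoint was written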
tools/launch.py
ADDED
@@ -0,0 +1,192 @@
# --------------------------------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# --------------------------------------------------------------------------------------------------------------------------
# Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py
# --------------------------------------------------------------------------------------------------------------------------

r"""
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will be operating on a single GPU. This can achieve
well-improved single-node training performance. It can also be used in
multi-node distributed training, by spawning up multiple processes on each node
for well-improved multi-node distributed training performance as well.
This will especially be beneficial for systems with multiple Infiniband
interfaces that have direct-GPU support, since all of them can be utilized for
aggregated communication bandwidth.
In both cases of single-node distributed training or multi-node distributed
training, this utility will launch the given number of processes per node
(``--nproc_per_node``). If used for GPU training, this number needs to be less
than or equal to the number of GPUs on the current system (``nproc_per_node``),
and each process will be operating on a single GPU from *GPU 0 to
GPU (nproc_per_node - 1)*.
**How to use this module:**
1. Single-Node multi-process distributed training
::
    >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
               arguments of your training script)
2. Multi-Node multi-process distributed training: (e.g. two nodes)
Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*
::
    >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
               --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)
Node 2:
::
    >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node_rank=1 --master_addr="192.168.1.1"
               --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)
3. To look up what optional arguments this module offers:
::
    >>> python -m torch.distributed.launch --help
**Important Notices:**
1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently only achieves the best performance using
the NCCL distributed backend. Thus NCCL backend is the recommended backend to
use for GPU training.
2. In your training program, you must parse the command-line argument:
``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:
Parsing the local_rank argument
::
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local_rank", type=int)
    >>> args = parser.parse_args()
Set your device to local rank using either
::
    >>> torch.cuda.set_device(arg.local_rank)  # before your code runs
or
::
    >>> with torch.cuda.device(arg.local_rank):
    >>>    # your code to run
3. In your training program, you are supposed to call the following function
at the beginning to start the distributed backend. You need to make sure that
the init_method uses ``env://``, which is the only supported ``init_method``
by this module.
::
    torch.distributed.init_process_group(backend='YOUR BACKEND',
                                         init_method='env://')
4. In your training program, you can either use regular distributed functions
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use
:func:`torch.nn.parallel.DistributedDataParallel` module,
here is how to configure it.
::
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[arg.local_rank],
                                                      output_device=arg.local_rank)
Please ensure that ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility
5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script
with ``--use_env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local_rank`` when you specify this flag.
.. warning::
    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine. Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem. See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
"""


import sys
import subprocess
import os
import socket
from argparse import ArgumentParser, REMAINDER

import torch


def parse_args():
    """
    Helper function parsing the command line options
    @retval ArgumentParser
    """
    parser = ArgumentParser(description="PyTorch distributed training launch "
                                        "helper utility that will spawn up "
                                        "multiple distributed processes")

    # Optional arguments for the launch helper
    parser.add_argument("--nnodes", type=int, default=1,
                        help="The number of nodes to use for distributed "
                             "training")
    parser.add_argument("--node_rank", type=int, default=0,
                        help="The rank of the node for multi-node distributed "
                             "training")
    parser.add_argument("--nproc_per_node", type=int, default=1,
                        help="The number of processes to launch on each node, "
                             "for GPU training, this is recommended to be set "
                             "to the number of GPUs in your system so that "
                             "each process can be bound to a single GPU.")
    parser.add_argument("--master_addr", default="127.0.0.1", type=str,
                        help="Master node (rank 0)'s address, should be either "
                             "the IP address or the hostname of node 0, for "
                             "single node multi-proc training, the "
                             "--master_addr can simply be 127.0.0.1")
    parser.add_argument("--master_port", default=29500, type=int,
                        help="Master node (rank 0)'s free port that needs to "
                             "be used for communication during distributed "
                             "training")

    # positional
    parser.add_argument("training_script", type=str,
                        help="The full path to the single GPU training "
                             "program/script to be launched in parallel, "
                             "followed by all the arguments for the "
                             "training script")

    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()


def main():
    args = parse_args()

    # world size in terms of number of processes
    dist_world_size = args.nproc_per_node * args.nnodes

    # set PyTorch distributed related environmental variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)

    processes = []

    for local_rank in range(0, args.nproc_per_node):
        # each process's rank
        dist_rank = args.nproc_per_node * args.node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)

        cmd = [args.training_script] + args.training_script_args

        process = subprocess.Popen(cmd, env=current_env)
        processes.append(process)

    for process in processes:
        process.wait()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(returncode=process.returncode,
                                                cmd=process.args)


if __name__ == "__main__":
    main()
tools/run_dist_launch.sh
ADDED
@@ -0,0 +1,29 @@
#!/usr/bin/env bash
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

set -x

GPUS=$1
RUN_COMMAND=${@:2}
if [ $GPUS -lt 8 ]; then
    GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
else
    GPUS_PER_NODE=${GPUS_PER_NODE:-8}
fi
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-"29500"}
NODE_RANK=${NODE_RANK:-0}

let "NNODES=GPUS/GPUS_PER_NODE"

python ./tools/launch.py \
    --nnodes ${NNODES} \
    --node_rank ${NODE_RANK} \
    --master_addr ${MASTER_ADDR} \
    --master_port ${MASTER_PORT} \
    --nproc_per_node ${GPUS_PER_NODE} \
    ${RUN_COMMAND}
tools/run_dist_slurm.sh
ADDED
@@ -0,0 +1,33 @@
#!/usr/bin/env bash
# --------------------------------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# --------------------------------------------------------------------------------------------------------------------------
# Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh
# --------------------------------------------------------------------------------------------------------------------------

set -x

PARTITION=$1
JOB_NAME=$2
GPUS=$3
RUN_COMMAND=${@:4}
if [ $GPUS -lt 8 ]; then
    GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS}
else
    GPUS_PER_NODE=${GPUS_PER_NODE:-8}
fi
CPUS_PER_TASK=${CPUS_PER_TASK:-4}
SRUN_ARGS=${SRUN_ARGS:-""}

srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    ${RUN_COMMAND}
v-coco
ADDED
@@ -0,0 +1 @@
/data/public/rw/datasets/v-coco/