deep
- .gitattributes +2 -0
- .gitignore +11 -0
- Dockerfile +55 -0
- LICENSE +201 -0
- README.md +207 -10
- config.yaml +4 -0
- convert_tf_to_pt.sh +6 -0
- copy_weights.py +36 -0
- datasets.py +162 -0
- detect_faces_on_videos.py +82 -0
- dsfacedetector/__init__.py +0 -0
- dsfacedetector/data/__init__.py +0 -0
- dsfacedetector/data/config.py +57 -0
- dsfacedetector/face_ssd_infer.py +156 -0
- dsfacedetector/layers/__init__.py +3 -0
- dsfacedetector/layers/detection.py +157 -0
- dsfacedetector/layers/modules.py +98 -0
- dsfacedetector/layers/prior_box.py +133 -0
- dsfacedetector/utils.py +101 -0
- external_data/convert_tf_to_pt.py +174 -0
- external_data/original_tf/__init__.py +0 -0
- external_data/original_tf/efficientnet_builder.py +329 -0
- external_data/original_tf/efficientnet_model.py +713 -0
- external_data/original_tf/eval_ckpt_main.py +221 -0
- external_data/original_tf/preprocessing.py +241 -0
- external_data/original_tf/utils.py +405 -0
- extract_tracks_from_videos.py +105 -0
- generate_aligned_tracks.py +99 -0
- generate_track_pairs.py +70 -0
- generate_tracks.py +70 -0
- images/augmented_mixup.jpg +3 -0
- images/clip_example.jpg +0 -0
- images/first_and_second_model_inputs.jpg +0 -0
- images/mixup_example.jpg +3 -0
- images/pred_transform.jpg +0 -0
- images/third_model_input.jpg +0 -0
- models/.gitkeep +0 -0
- predict.py +399 -0
- tracker/__init__.py +0 -0
- tracker/iou_tracker.py +58 -0
- tracker/utils.py +35 -0
- train_b7_ns_aa_original_large_crop_100k.py +257 -0
- train_b7_ns_aa_original_re_100k.py +266 -0
- train_b7_ns_seq_aa_original_100k.py +281 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/augmented_mixup.jpg filter=lfs diff=lfs merge=lfs -text
+images/mixup_example.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,11 @@
# PyCharm
.idea

# Jupyter Notebook
.ipynb_checkpoints

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
Dockerfile
ADDED
@@ -0,0 +1,55 @@
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04

SHELL ["/bin/bash", "-c"]

RUN rm /etc/apt/sources.list.d/cuda.list \
       /etc/apt/sources.list.d/nvidia-ml.list && \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        software-properties-common \
        wget \
        git && \
    add-apt-repository -y ppa:deadsnakes/ppa && \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        python3.6 \
        python3.6-dev && \
    wget -O ~/get-pip.py \
        https://bootstrap.pypa.io/get-pip.py && \
    python3.6 ~/get-pip.py && \
    pip3 --no-cache-dir install \
        numpy==1.17.4 \
        PyYAML==5.1.2 \
        mkl==2019.0 \
        mkl-include==2019.0 \
        cmake==3.15.3 \
        cffi==1.13.2 \
        typing==3.7.4.1 \
        six==1.13.0 \
        Pillow==6.2.1 \
        scipy==1.4.1 && \
    cd /tmp && \
    git clone https://github.com/pytorch/pytorch.git && \
    cd pytorch && \
    git checkout v1.3.0 && \
    git submodule update --init --recursive && \
    python3.6 setup.py install && \
    cd /tmp && \
    git clone https://github.com/pytorch/vision.git && \
    cd vision && \
    git checkout v0.4.1 && \
    python3.6 setup.py install && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ffmpeg && \
    pip3 --no-cache-dir install \
        opencv-python==4.1.2.30 \
        albumentations==0.4.3 \
        tqdm==4.39.0 \
        timm==0.1.18 \
        efficientnet-pytorch==0.6.3 \
        ffmpeg-python==0.2.0 \
        tensorflow==1.15.2 && \
    cd / && \
    apt-get clean && \
    apt-get autoremove && \
    rm -rf /var/lib/apt/lists/* /tmp/*
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2020 N-TECH.LAB LTD
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,10 +1,207 @@
# Deepfake Detection Challenge
Solution for the [Deepfake Detection Challenge](https://www.kaggle.com/c/deepfake-detection-challenge).
Private LB score: **0.43452**
## Solution description
### Summary
Our solution consists of three EfficientNet-B7 models (we used the Noisy Student pre-trained weights). We did not use
external data, except for pre-trained weights. One model runs on frame sequences (a 3D convolution has been added to
each EfficientNet-B7 block). The other two models work frame-by-frame and differ in the size of the face crop and
augmentations during training. To tackle the overfitting problem, we used the mixup technique on aligned real-fake
pairs. In addition, we used the following augmentations: AutoAugment, Random Erasing, Random Crops, Random Flips, and
various video compression parameters. Video compression augmentation was done on-the-fly. To do this, short cropped
tracks (50 frames each) were saved in PNG format, and at each training iteration they were loaded and re-encoded with
random parameters using ffmpeg. Due to the mixup, model predictions were “uncertain”, so at the inference stage, model
confidence was strengthened by a simple transformation. The final prediction was obtained by averaging the predictions
of models with weights proportional to confidence. The total training and preprocessing time is approximately 5 days on
a DGX-1.
### Key ingredients
#### Mixup on aligned real-fake pairs
One of the main difficulties of this competition is severe overfitting. Initially, all models overfitted in 2-3 epochs
(the validation loss started to increase). The idea, which helped a lot with the overfitting, is to train the model on
a mix of real and fake faces: for each fake face, we take the corresponding real face from the original video (with the
same box coordinates and the same frame number) and do a linear combination of them. In terms of tensors, it is:
```python
input_tensor = (1.0 - target) * real_input_tensor + target * fake_input_tensor
```
where `target` is drawn from a Beta distribution with parameters alpha=beta=0.5. With these parameters, there is a very
high probability of picking values close to 0 or 1 (pure real or pure fake face). You can see examples below:

Because the real and fake samples are aligned, the background remains almost unchanged on interpolated samples,
which reduces overfitting and makes the model pay more attention to the face.
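This can be expressed as a small batch-level helper. Below is a minimal PyTorch sketch, assuming `real_faces` and `fake_faces` are aligned batches produced by the pair dataset; the helper name and interface are illustrative, not the exact training-script API:
```python
import torch
from torch.distributions import Beta

def mixup_aligned_pair(real_faces, fake_faces, alpha=0.5):
    """Blend aligned real/fake crops; the blend factor doubles as the soft label."""
    # One blend factor per sample, drawn from Beta(0.5, 0.5): mass concentrated near 0 and 1.
    target = Beta(alpha, alpha).sample((real_faces.size(0),)).to(real_faces.device)
    t = target.view(-1, 1, 1, 1)
    mixed = (1.0 - t) * real_faces + t * fake_faces
    return mixed, target  # target is used directly as the "fake" probability label
```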
#### Video compression augmentation
In the paper \[1\] it was pointed out that augmentations close to degradations seen in real-life video distributions
were applied to the test data. Specifically, these augmentations were (1) reduce the FPS of the video to 15; (2) reduce
the resolution of the video to 1/4 of its original size; and (3) reduce the overall encoding quality. In order to make
the model resistant to various video compression parameters, we added augmentations with random video encoding
parameters to training. It would be infeasible to apply such augmentations to the original videos on-the-fly during
training, so instead of the original videos, cropped (1.5x areas around the face) short (50 frames) clips were used.
Each clip was saved as separate frames in PNG format. An example of a clip is given below:

For on-the-fly augmentation, ffmpeg-python was used. At each iteration, the following parameters were randomly sampled
(see \[2\] and the sketch below):
- FPS (15 to 30)
- scale (0.25 to 1.0)
- CRF (17 to 40)
- random tuning option
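A rough ffmpeg-python sketch of one such re-encoding step, assuming the cropped clip has already been assembled into a temporary `.mp4`; the helper name, the tune-option list, and the even-size rounding are illustrative assumptions rather than the exact code in the training scripts:
```python
import random
import ffmpeg  # ffmpeg-python

def reencode_clip(src_path, dst_path):
    """Re-encode a short face clip with random H.264 compression parameters."""
    fps = random.randint(15, 30)
    scale = random.uniform(0.25, 1.0)
    crf = random.randint(17, 40)
    tune = random.choice(['film', 'animation', 'grain', 'fastdecode', 'zerolatency'])

    stream = ffmpeg.input(src_path)
    stream = stream.filter('fps', fps=fps)
    # trunc(iw*scale/2)*2 keeps the scaled width/height even, as libx264 requires
    stream = stream.filter('scale',
                           'trunc(iw*{}/2)*2'.format(scale),
                           'trunc(ih*{}/2)*2'.format(scale))
    stream = stream.output(dst_path, vcodec='libx264', crf=crf, tune=tune)
    stream.overwrite_output().run(quiet=True)
```
Decoding the re-encoded clip back into frames (e.g. with OpenCV) is omitted here.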
#### Model architecture
As a result of the experiments, we found that the EfficientNet models work better than others (we checked ResNet,
ResNeXt, SE-ResNeXt). The best model was EfficientNet-B7 with Noisy Student pre-trained weights \[3\]. The size of the
input image is 224x192 (most of the faces in the training dataset are smaller). The final ensemble consists of three
models, two of which are frame-by-frame, and the third works on sequences.
##### Frame-by-frame models
Frame-by-frame models work quite well. They differ in the size of the area around the face and the augmentations used
during training. Below are examples of input images for each of the models:

##### Sequence-based model
Temporal dependencies can probably be useful for detecting fakes. Therefore, we added a 3D convolution to each block of
the EfficientNet model. This model worked slightly better than a similar frame-by-frame model. The length of the input
sequence is 7 frames. The step between frames is 1/15 of a second. An example of an input sequence is given below:

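A minimal sketch of one way to wrap an existing 2D block with an extra temporal convolution; this illustrates the idea only, and the wrapper interface, kernel shape, and channel handling are assumptions rather than the exact module used in `train_b7_ns_seq_aa_original_100k.py`:
```python
import torch.nn as nn

class TemporalBlock(nn.Module):
    """Wraps a 2D EfficientNet block and mixes information across the time axis."""
    def __init__(self, block2d, channels, seq_len=7):
        super().__init__()
        self.block2d = block2d
        self.seq_len = seq_len
        # Depthwise 3D conv over (T, H, W); padding keeps the sequence length unchanged.
        self.conv3d = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1),
                                padding=(1, 0, 0), groups=channels)

    def forward(self, x):
        # x: (batch * seq_len, C, H, W) - frames of each clip stacked along the batch dim
        x = self.block2d(x)
        bt, c, h, w = x.shape
        b = bt // self.seq_len
        x = x.view(b, self.seq_len, c, h, w).permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x = self.conv3d(x)
        x = x.permute(0, 2, 1, 3, 4).reshape(bt, c, h, w)
        return x
```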
#### Image augmentations
To improve model generalization, we used the following augmentations: AutoAugment \[4\], Random Erasing, Random Crops,
Random Horizontal Flips. Since we used mixup, it was important to augment real-fake pairs in the same way (see the
example below). For the sequence-based model, it was also important to augment all frames belonging to the same clip in
the same way.

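In `datasets.py`, `TrackPairDataset` achieves this by saving and restoring Python's global random state before transforming each member of a pair, so the augmentation parameters sampled by albumentations coincide. A stripped-down sketch of that trick:
```python
import random

def augment_pair(real_img, fake_img, transform):
    """Apply the same randomly sampled augmentation to both members of a real-fake pair."""
    state = random.getstate()
    real_aug = transform(image=real_img)['image']
    random.setstate(state)              # replay the same random decisions
    fake_aug = transform(image=fake_img)['image']
    return real_aug, fake_aug
```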
#### Inference post-processing
Due to mixup, the predictions of the models were uncertain, which was not optimal for the logloss. To increase
confidence, we applied the following transformation:

Due to computational limitations, predictions are made on a subsample of frames. Half of the frames were horizontally
flipped. The prediction for the video is obtained by averaging all the predictions with weights proportional to the
confidence (the closer a prediction is to 0.5, the lower its weight). Such averaging works like attention, because the
model gives predictions close to 0.5 on poor-quality frames (profile faces, blur, etc.).
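A small NumPy sketch of this confidence-weighted averaging; the weighting here uses the distance from 0.5 as the confidence, which matches the description above but may differ in detail from `predict.py`:
```python
import numpy as np

def aggregate_video_prediction(frame_preds, eps=1e-6):
    """Average per-frame fake probabilities, down-weighting uncertain (near 0.5) frames."""
    frame_preds = np.asarray(frame_preds, dtype=np.float32)
    weights = np.abs(frame_preds - 0.5) + eps   # confidence of each frame prediction
    return float(np.sum(weights * frame_preds) / np.sum(weights))

# Example: confident frames dominate the video-level score
print(aggregate_video_prediction([0.51, 0.93, 0.97, 0.49]))
```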
#### References
\[1\] Brian Dolhansky, Russ Howes, Ben Pflaum, Nicole Baram, Cristian Canton Ferrer, “The Deepfake Detection Challenge
(DFDC) Preview Dataset”
\[2\] [https://trac.ffmpeg.org/wiki/Encode/H.264](https://trac.ffmpeg.org/wiki/Encode/H.264)
\[3\] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, Quoc V. Le, “Self-training with Noisy Student improves ImageNet classification”
\[4\] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, Quoc V. Le, “AutoAugment: Learning Augmentation Policies from Data”
## The hardware we used
- CPU: Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
- GPU: 8x NVIDIA Tesla V100 SXM2 32 GB
- RAM: 512 GB
- SSD: 6 TB
## Prerequisites
### Environment
Use Docker to get an environment close to the one used for training. Run the following command to build the Docker image:
```bash
cd path/to/solution
sudo docker build -t dfdc .
```
### Data
Download the [deepfake-detection-challenge-data](https://www.kaggle.com/c/deepfake-detection-challenge/data) and extract all files to `/path/to/dfdc-data`. This directory must have the following structure:
```
dfdc-data
├── dfdc_train_part_0
├── dfdc_train_part_1
├── dfdc_train_part_10
├── dfdc_train_part_11
├── dfdc_train_part_12
├── dfdc_train_part_13
├── dfdc_train_part_14
├── dfdc_train_part_15
├── dfdc_train_part_16
├── dfdc_train_part_17
├── dfdc_train_part_18
├── dfdc_train_part_19
├── dfdc_train_part_2
├── dfdc_train_part_20
├── dfdc_train_part_21
├── dfdc_train_part_22
├── dfdc_train_part_23
├── dfdc_train_part_24
├── dfdc_train_part_25
├── dfdc_train_part_26
├── dfdc_train_part_27
├── dfdc_train_part_28
├── dfdc_train_part_29
├── dfdc_train_part_3
├── dfdc_train_part_30
├── dfdc_train_part_31
├── dfdc_train_part_32
├── dfdc_train_part_33
├── dfdc_train_part_34
├── dfdc_train_part_35
├── dfdc_train_part_36
├── dfdc_train_part_37
├── dfdc_train_part_38
├── dfdc_train_part_39
├── dfdc_train_part_4
├── dfdc_train_part_40
├── dfdc_train_part_41
├── dfdc_train_part_42
├── dfdc_train_part_43
├── dfdc_train_part_44
├── dfdc_train_part_45
├── dfdc_train_part_46
├── dfdc_train_part_47
├── dfdc_train_part_48
├── dfdc_train_part_49
├── dfdc_train_part_5
├── dfdc_train_part_6
├── dfdc_train_part_7
├── dfdc_train_part_8
├── dfdc_train_part_9
└── test_videos
```

### External data
According to the rules of the competition, external data is allowed. The solution does not use any other external data except pre-trained models. Below is a table with information about these models.

| File Name | Source | Direct Link | Forum Post |
| --------- | ------ | ----------- | ---------- |
| WIDERFace_DSFD_RES152.pth | [github](https://github.com/Tencent/FaceDetection-DSFD/tree/31aa8bdeaf01a0c408adaf2709754a16b17aec79) | [google drive](https://drive.google.com/file/d/1WeXlNYsM6dMP3xQQELI-4gxhwKUQxc3-/view) | [link](https://www.kaggle.com/c/deepfake-detection-challenge/discussion/121203#761391) |
| noisy_student_efficientnet-b7.tar.gz | [github](https://github.com/tensorflow/tpu/tree/4719695c9128622fb26dedb19ea19bd9d1ee3177/models/official/efficientnet) | [link](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b7.tar.gz) | [link](https://www.kaggle.com/c/deepfake-detection-challenge/discussion/121203#748358) |

Download these files and copy them to the `external_data` folder.

## How to train the model
Run the Docker container with the paths correctly mounted:
```bash
sudo docker run --runtime=nvidia -i -t -d --rm --ipc=host -v /path/to/dfdc-data:/kaggle/input/deepfake-detection-challenge:ro -v /path/to/solution:/kaggle/solution --name dfdc dfdc
sudo docker exec -it dfdc /bin/bash
cd /kaggle/solution
```
Convert the pre-trained model from TensorFlow to PyTorch:
```bash
bash convert_tf_to_pt.sh
```
Detect faces on videos:
```bash
python3.6 detect_faces_on_videos.py
```
_Note: You can parallelize this operation using the `--part` and `--num_parts` arguments._
Generate tracks:
```bash
python3.6 generate_tracks.py
```
Generate aligned tracks:
```bash
python3.6 generate_aligned_tracks.py
```
Extract tracks from videos:
```bash
python3.6 extract_tracks_from_videos.py
```
_Note: You can parallelize this operation using the `--part` and `--num_parts` arguments._
Generate track pairs:
```bash
python3.6 generate_track_pairs.py
```
Train the models:
```bash
python3.6 train_b7_ns_aa_original_large_crop_100k.py
python3.6 train_b7_ns_aa_original_re_100k.py
python3.6 train_b7_ns_seq_aa_original_100k.py
```
Copy the final weights and convert them to FP16:
```bash
python3.6 copy_weights.py
```
## Serialized copy of the trained model
You can download the final weights that were used in the competition (the result of the `copy_weights.py` script): [GoogleDrive](https://drive.google.com/file/d/1S-HeppZcbXDF0F-BO96zhqZyrRWOaan6/view?usp=sharing)
## How to generate submission
Run the following command:
```bash
python3.6 predict.py
```
config.yaml
ADDED
@@ -0,0 +1,4 @@
DFDC_DATA_PATH: "/kaggle/input/deepfake-detection-challenge"
ARTIFACTS_PATH: "/kaggle/solution/artifacts"
MODELS_PATH: "/kaggle/solution/models"
SUBMISSION_PATH: "/kaggle/solution/output/submission.csv"
convert_tf_to_pt.sh
ADDED
@@ -0,0 +1,6 @@
#!/bin/bash
cd external_data && \
tar -xzf noisy_student_efficientnet-b7.tar.gz && \
python3.6 convert_tf_to_pt.py --model_name efficientnet-b7 --tf_checkpoint noisy-student-efficientnet-b7 --output_file noisy_student_efficientnet-b7.pth && \
rm -rf noisy-student-efficientnet-b7 tmp && \
cd ..
copy_weights.py
ADDED
@@ -0,0 +1,36 @@
import os
import yaml

import torch

WEIGHTS_MAPPING = {
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k/snapshot_100000.pth': 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_re_100k/snapshot_100000.pth': 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_seq_aa-original-mstd0.5_100k/snapshot_100000.pth': 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth'
}

SRC_DETECTOR_WEIGHTS = 'external_data/WIDERFace_DSFD_RES152.pth'
DST_DETECTOR_WEIGHTS = 'WIDERFace_DSFD_RES152.fp16.pth'


def copy_weights(src_path, dst_path):
    state = torch.load(src_path, map_location=lambda storage, loc: storage)
    state = {key: value.half() for key, value in state.items()}
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    torch.save(state, dst_path)


def main():
    with open('config.yaml', 'r') as f:
        config = yaml.load(f)

    for src_rel_path, dst_rel_path in WEIGHTS_MAPPING.items():
        src_path = os.path.join(config['ARTIFACTS_PATH'], src_rel_path)
        dst_path = os.path.join(config['MODELS_PATH'], dst_rel_path)
        copy_weights(src_path, dst_path)

    copy_weights(SRC_DETECTOR_WEIGHTS, os.path.join(config['MODELS_PATH'], DST_DETECTOR_WEIGHTS))


if __name__ == '__main__':
    main()
datasets.py
ADDED
@@ -0,0 +1,162 @@
import os
import random
import glob

import cv2
import numpy as np

from torch.utils.data import Dataset


class UnlabeledVideoDataset(Dataset):
    def __init__(self, root_dir, content=None, transform=None):
        self.root_dir = os.path.normpath(root_dir)
        self.transform = transform

        if content is not None:
            self.content = content
        else:
            self.content = []
            for path in glob.iglob(os.path.join(self.root_dir, '**', '*.mp4'), recursive=True):
                rel_path = path[len(self.root_dir) + 1:]
                self.content.append(rel_path)
            self.content = sorted(self.content)

    def __len__(self):
        return len(self.content)

    def __getitem__(self, idx):
        rel_path = self.content[idx]
        path = os.path.join(self.root_dir, rel_path)

        capture = cv2.VideoCapture(path)

        frames = []
        if capture.isOpened():
            while True:
                ret, frame = capture.read()
                if not ret:
                    break

                if self.transform is not None:
                    frame = self.transform(frame)

                frames.append(frame)

        sample = {
            'frames': frames,
            'index': idx
        }

        return sample


class FaceDataset(Dataset):
    def __init__(self, root_dir, content, labels=None, transform=None):
        self.root_dir = os.path.normpath(root_dir)
        self.content = content
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.content)

    def __getitem__(self, idx):
        rel_path = self.content[idx]
        path = os.path.join(self.root_dir, rel_path)

        face = cv2.imread(path, cv2.IMREAD_COLOR)
        face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            face = self.transform(image=face)['image']

        sample = {
            'face': face,
            'index': idx
        }

        if self.labels is not None:
            sample['label'] = self.labels[idx]

        return sample


class TrackPairDataset(Dataset):
    FPS = 30

    def __init__(self, tracks_root, pairs_path, indices, track_length, track_transform=None, image_transform=None,
                 sequence_mode=True):
        self.tracks_root = os.path.normpath(tracks_root)
        self.track_transform = track_transform
        self.image_transform = image_transform
        self.indices = np.asarray(indices, dtype=np.int32)
        self.track_length = track_length
        self.sequence_mode = sequence_mode

        self.pairs = []
        with open(pairs_path, 'r') as f:
            for line in f:
                real_track, fake_track = line.strip().split(',')
                self.pairs.append((real_track, fake_track))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        real_track_path, fake_track_path = self.pairs[idx]

        real_track_path = os.path.join(self.tracks_root, real_track_path)
        fake_track_path = os.path.join(self.tracks_root, fake_track_path)

        if self.track_transform is not None:
            img = self.load_img(real_track_path, 0)
            src_height, src_width = img.shape[:2]
            track_transform_params = self.track_transform.get_params(self.FPS, src_height, src_width)
        else:
            track_transform_params = None

        real_track = self.load_track(real_track_path, self.indices, track_transform_params)
        fake_track = self.load_track(fake_track_path, self.indices, track_transform_params)

        if self.image_transform is not None:
            prev_state = random.getstate()
            transformed_real_track = []
            for img in real_track:
                if self.sequence_mode:
                    random.setstate(prev_state)
                transformed_real_track.append(self.image_transform(image=img)['image'])

            real_track = transformed_real_track

            random.setstate(prev_state)
            transformed_fake_track = []
            for img in fake_track:
                if self.sequence_mode:
                    random.setstate(prev_state)
                transformed_fake_track.append(self.image_transform(image=img)['image'])
            fake_track = transformed_fake_track

        sample = {
            'real': real_track,
            'fake': fake_track
        }

        return sample

    def load_img(self, track_path, idx):
        img = cv2.imread(os.path.join(track_path, '{}.png'.format(idx)))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        return img

    def load_track(self, track_path, indices, transform_params):
        if transform_params is None:
            track = np.stack([self.load_img(track_path, idx) for idx in indices])
        else:
            track = self.track_transform(track_path, self.FPS, *transform_params)
            indices = (indices.astype(np.float32) / self.track_length) * len(track)
            indices = np.round(indices).astype(np.int32).clip(0, len(track) - 1)
            track = track[indices]

        return track
detect_faces_on_videos.py
ADDED
@@ -0,0 +1,82 @@
import argparse
import os
import glob
import yaml
import pickle
import tqdm

import torch
from torch.utils.data import DataLoader

from dsfacedetector.face_ssd_infer import SSD
from datasets import UnlabeledVideoDataset

DETECTOR_WEIGHTS_PATH = 'external_data/WIDERFace_DSFD_RES152.pth'
DETECTOR_THRESHOLD = 0.3
DETECTOR_STEP = 6
DETECTOR_TARGET_SIZE = (512, 512)

BATCH_SIZE = 1
NUM_WORKERS = 0

DETECTIONS_ROOT = 'detections'
DETECTIONS_FILE_NAME = 'detections.pkl'


def main():
    parser = argparse.ArgumentParser(description='Detects faces on videos')
    parser.add_argument('--num_parts', type=int, default=1, help='Number of parts')
    parser.add_argument('--part', type=int, default=0, help='Part index')

    args = parser.parse_args()

    with open('config.yaml', 'r') as f:
        config = yaml.load(f)

    content = []
    for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], 'dfdc_train_part_*', '*.mp4')):
        parts = path.split('/')
        content.append('/'.join(parts[-2:]))
    content = sorted(content)

    print('Total number of videos: {}'.format(len(content)))

    part_size = len(content) // args.num_parts + 1
    assert part_size * args.num_parts >= len(content)
    part_start = part_size * args.part
    part_end = min(part_start + part_size, len(content))
    print('Part {} ({}, {})'.format(args.part, part_start, part_end))

    dataset = UnlabeledVideoDataset(config['DFDC_DATA_PATH'], content[part_start:part_end])

    detector = SSD('test')
    state = torch.load(DETECTOR_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
    detector.load_state_dict(state)
    device = torch.device('cuda')
    detector = detector.eval().to(device)

    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=lambda X: X,
                        drop_last=False)

    dst_root = os.path.join(config['ARTIFACTS_PATH'], DETECTIONS_ROOT)
    os.makedirs(dst_root, exist_ok=True)

    for video_sample in tqdm.tqdm(loader):
        frames = video_sample[0]['frames']
        video_idx = video_sample[0]['index']
        video_rel_path = dataset.content[video_idx]

        detections = []
        for frame in frames[::DETECTOR_STEP]:
            with torch.no_grad():
                detections_per_frame = detector.detect_on_image(frame, DETECTOR_TARGET_SIZE, device, is_pad=False,
                                                                keep_thresh=DETECTOR_THRESHOLD)
            detections.append({'boxes': detections_per_frame[:, :4], 'scores': detections_per_frame[:, 4]})

        os.makedirs(os.path.join(dst_root, video_rel_path), exist_ok=True)
        with open(os.path.join(dst_root, video_rel_path, DETECTIONS_FILE_NAME), 'wb') as f:
            pickle.dump(detections, f)


if __name__ == '__main__':
    main()
dsfacedetector/__init__.py
ADDED
File without changes
dsfacedetector/data/__init__.py
ADDED
File without changes
dsfacedetector/data/config.py
ADDED
@@ -0,0 +1,57 @@
import numpy as np


def test_base_transform(image, mean):
    x = image.astype(np.float32)
    x -= mean
    x = x.astype(np.float32)
    return x


class TestBaseTransform:
    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image):
        return test_base_transform(image, self.mean)


widerface_640 = {
    'num_classes': 2,

    'feature_maps': [160, 80, 40, 20, 10, 5],
    'min_dim': 640,

    'steps': [4, 8, 16, 32, 64, 128],  # stride

    'variance': [0.1, 0.2],
    'clip': True,  # make default box in [0,1]
    'name': 'WIDERFace',
    'l2norm_scale': [10, 8, 5],
    'base': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512],
    'extras': [256, 'S', 512, 128, 'S', 256],

    'mbox': [1, 1, 1, 1, 1, 1],
    'min_sizes': [16, 32, 64, 128, 256, 512],
    'max_sizes': [],
    'aspect_ratios': [[1.5], [1.5], [1.5], [1.5], [1.5], [1.5]],  # [1,2] default 1

    'backbone': 'resnet152',
    'feature_pyramid_network': True,
    'bottom_up_path': False,
    'feature_enhance_module': True,
    'max_in_out': True,
    'focal_loss': False,
    'progressive_anchor': True,
    'refinedet': False,
    'max_out': False,
    'anchor_compensation': False,
    'data_anchor_sampling': False,

    'overlap_thresh': [0.4],
    'negpos_ratio': 3,
    # test
    'nms_thresh': 0.3,
    'conf_thresh': 0.01,
    'num_thresh': 5000,
}
dsfacedetector/face_ssd_infer.py
ADDED
|
@@ -0,0 +1,156 @@
|
| 1 |
+
# Source: https://github.com/vlad3996/FaceDetection-DSFD
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .data.config import TestBaseTransform, widerface_640 as cfg
|
| 8 |
+
from .layers import Detect, get_prior_boxes, FEM, pa_multibox, mio_module, upsample_product
|
| 9 |
+
from .utils import resize_image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SSD(nn.Module):
|
| 13 |
+
|
| 14 |
+
def __init__(self, phase, nms_thresh=0.3, nms_conf_thresh=0.01):
|
| 15 |
+
super(SSD, self).__init__()
|
| 16 |
+
self.phase = phase
|
| 17 |
+
self.num_classes = 2
|
| 18 |
+
self.cfg = cfg
|
| 19 |
+
|
| 20 |
+
resnet = torchvision.models.resnet152(pretrained=False)
|
| 21 |
+
|
| 22 |
+
self.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1)
|
| 23 |
+
self.layer2 = nn.Sequential(resnet.layer2)
|
| 24 |
+
self.layer3 = nn.Sequential(resnet.layer3)
|
| 25 |
+
self.layer4 = nn.Sequential(resnet.layer4)
|
| 26 |
+
self.layer5 = nn.Sequential(
|
| 27 |
+
*[nn.Conv2d(2048, 512, kernel_size=1),
|
| 28 |
+
nn.BatchNorm2d(512),
|
| 29 |
+
nn.ReLU(inplace=True),
|
| 30 |
+
nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=2),
|
| 31 |
+
nn.BatchNorm2d(512),
|
| 32 |
+
nn.ReLU(inplace=True)]
|
| 33 |
+
)
|
| 34 |
+
self.layer6 = nn.Sequential(
|
| 35 |
+
*[nn.Conv2d(512, 128, kernel_size=1, ),
|
| 36 |
+
nn.BatchNorm2d(128),
|
| 37 |
+
nn.ReLU(inplace=True),
|
| 38 |
+
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
|
| 39 |
+
nn.BatchNorm2d(256),
|
| 40 |
+
nn.ReLU(inplace=True)]
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
output_channels = [256, 512, 1024, 2048, 512, 256]
|
| 44 |
+
|
| 45 |
+
# FPN
|
| 46 |
+
fpn_in = output_channels
|
| 47 |
+
|
| 48 |
+
self.latlayer3 = nn.Conv2d(fpn_in[3], fpn_in[2], kernel_size=1, stride=1, padding=0)
|
| 49 |
+
self.latlayer2 = nn.Conv2d(fpn_in[2], fpn_in[1], kernel_size=1, stride=1, padding=0)
|
| 50 |
+
self.latlayer1 = nn.Conv2d(fpn_in[1], fpn_in[0], kernel_size=1, stride=1, padding=0)
|
| 51 |
+
|
| 52 |
+
self.smooth3 = nn.Conv2d(fpn_in[2], fpn_in[2], kernel_size=1, stride=1, padding=0)
|
| 53 |
+
self.smooth2 = nn.Conv2d(fpn_in[1], fpn_in[1], kernel_size=1, stride=1, padding=0)
|
| 54 |
+
self.smooth1 = nn.Conv2d(fpn_in[0], fpn_in[0], kernel_size=1, stride=1, padding=0)
|
| 55 |
+
|
| 56 |
+
# FEM
|
| 57 |
+
cpm_in = output_channels
|
| 58 |
+
|
| 59 |
+
self.cpm3_3 = FEM(cpm_in[0])
|
| 60 |
+
self.cpm4_3 = FEM(cpm_in[1])
|
| 61 |
+
self.cpm5_3 = FEM(cpm_in[2])
|
| 62 |
+
self.cpm7 = FEM(cpm_in[3])
|
| 63 |
+
self.cpm6_2 = FEM(cpm_in[4])
|
| 64 |
+
self.cpm7_2 = FEM(cpm_in[5])
|
| 65 |
+
|
| 66 |
+
# head
|
| 67 |
+
head = pa_multibox(output_channels)
|
| 68 |
+
self.loc = nn.ModuleList(head[0])
|
| 69 |
+
self.conf = nn.ModuleList(head[1])
|
| 70 |
+
|
| 71 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 72 |
+
|
| 73 |
+
if self.phase != 'onnx_export':
|
| 74 |
+
self.detect = Detect(self.num_classes, 0, cfg['num_thresh'], nms_conf_thresh, nms_thresh,
|
| 75 |
+
cfg['variance'])
|
| 76 |
+
self.last_image_size = None
|
| 77 |
+
self.last_feature_maps = None
|
| 78 |
+
|
| 79 |
+
if self.phase == 'test':
|
| 80 |
+
self.test_transform = TestBaseTransform((104, 117, 123))
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
|
| 84 |
+
image_size = [x.shape[2], x.shape[3]]
|
| 85 |
+
loc = list()
|
| 86 |
+
conf = list()
|
| 87 |
+
|
| 88 |
+
conv3_3_x = self.layer1(x)
|
| 89 |
+
conv4_3_x = self.layer2(conv3_3_x)
|
| 90 |
+
conv5_3_x = self.layer3(conv4_3_x)
|
| 91 |
+
fc7_x = self.layer4(conv5_3_x)
|
| 92 |
+
conv6_2_x = self.layer5(fc7_x)
|
| 93 |
+
conv7_2_x = self.layer6(conv6_2_x)
|
| 94 |
+
|
| 95 |
+
lfpn3 = upsample_product(self.latlayer3(fc7_x), self.smooth3(conv5_3_x))
|
| 96 |
+
lfpn2 = upsample_product(self.latlayer2(lfpn3), self.smooth2(conv4_3_x))
|
| 97 |
+
lfpn1 = upsample_product(self.latlayer1(lfpn2), self.smooth1(conv3_3_x))
|
| 98 |
+
|
| 99 |
+
conv5_3_x = lfpn3
|
| 100 |
+
conv4_3_x = lfpn2
|
| 101 |
+
conv3_3_x = lfpn1
|
| 102 |
+
|
| 103 |
+
sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]
|
| 104 |
+
|
| 105 |
+
sources[0] = self.cpm3_3(sources[0])
|
| 106 |
+
sources[1] = self.cpm4_3(sources[1])
|
| 107 |
+
sources[2] = self.cpm5_3(sources[2])
|
| 108 |
+
sources[3] = self.cpm7(sources[3])
|
| 109 |
+
sources[4] = self.cpm6_2(sources[4])
|
| 110 |
+
sources[5] = self.cpm7_2(sources[5])
|
| 111 |
+
|
| 112 |
+
# apply multibox head to source layers
|
| 113 |
+
featuremap_size = []
|
| 114 |
+
for (x, l, c) in zip(sources, self.loc, self.conf):
|
| 115 |
+
featuremap_size.append([x.shape[2], x.shape[3]])
|
| 116 |
+
loc.append(l(x).permute(0, 2, 3, 1).contiguous())
|
| 117 |
+
len_conf = len(conf)
|
| 118 |
+
cls = mio_module(c(x), len_conf)
|
| 119 |
+
conf.append(cls.permute(0, 2, 3, 1).contiguous())
|
| 120 |
+
|
| 121 |
+
face_loc = torch.cat([o[:, :, :, :4].contiguous().view(o.size(0), -1) for o in loc], 1)
|
| 122 |
+
face_loc = face_loc.view(face_loc.size(0), -1, 4)
|
| 123 |
+
face_conf = torch.cat([o[:, :, :, :2].contiguous().view(o.size(0), -1) for o in conf], 1)
|
| 124 |
+
face_conf = self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes))
|
| 125 |
+
|
| 126 |
+
if self.phase != 'onnx_export':
|
| 127 |
+
|
| 128 |
+
if self.last_image_size is None or self.last_image_size != image_size or self.last_feature_maps != featuremap_size:
|
| 129 |
+
self.priors = get_prior_boxes(self.cfg, featuremap_size, image_size).to(face_loc.device)
|
| 130 |
+
self.last_image_size = image_size
|
| 131 |
+
self.last_feature_maps = featuremap_size
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
output = self.detect(face_loc, face_conf, self.priors)
|
| 134 |
+
else:
|
| 135 |
+
output = torch.cat((face_loc, face_conf), 2)
|
| 136 |
+
return output
|
| 137 |
+
|
| 138 |
+
def detect_on_image(self, source_image, target_size, device, is_pad=False, keep_thresh=0.3):
|
| 139 |
+
|
| 140 |
+
image, shift_h_scaled, shift_w_scaled, scale = resize_image(source_image, target_size, is_pad=is_pad)
|
| 141 |
+
|
| 142 |
+
x = torch.from_numpy(self.test_transform(image)).permute(2, 0, 1).to(device)
|
| 143 |
+
x.unsqueeze_(0)
|
| 144 |
+
|
| 145 |
+
detections = self.forward(x).cpu().numpy()
|
| 146 |
+
|
| 147 |
+
scores = detections[0, 1, :, 0]
|
| 148 |
+
keep_idxs = scores > keep_thresh # find keeping indexes
|
| 149 |
+
detections = detections[0, 1, keep_idxs, :] # select detections over threshold
|
| 150 |
+
detections = detections[:, [1, 2, 3, 4, 0]] # reorder
|
| 151 |
+
|
| 152 |
+
detections[:, [0, 2]] -= shift_w_scaled # 0 or pad percent from left corner
|
| 153 |
+
detections[:, [1, 3]] -= shift_h_scaled # 0 or pad percent from top
|
| 154 |
+
detections[:, :4] *= scale
|
| 155 |
+
|
| 156 |
+
return detections
|
dsfacedetector/layers/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .detection import Detect
|
| 2 |
+
from .prior_box import PriorBox, get_prior_boxes
|
| 3 |
+
from .modules import FEM, pa_multibox, mio_module, upsample_product
|
dsfacedetector/layers/detection.py
ADDED
@@ -0,0 +1,157 @@
| 1 |
+
from __future__ import division
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Detect(nn.Module):
|
| 7 |
+
"""At test time, Detect is the final layer of SSD. Decode location preds,
|
| 8 |
+
apply non-maximum suppression to location predictions based on conf
|
| 9 |
+
scores and threshold to a top_k number of output predictions for both
|
| 10 |
+
confidence score and locations.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh, variance=(0.1, 0.2)):
|
| 14 |
+
super(Detect, self).__init__()
|
| 15 |
+
self.num_classes = num_classes
|
| 16 |
+
self.background_label = bkg_label
|
| 17 |
+
self.top_k = top_k
|
| 18 |
+
# Parameters used in nms.
|
| 19 |
+
self.nms_thresh = nms_thresh
|
| 20 |
+
if nms_thresh <= 0:
|
| 21 |
+
raise ValueError('nms_threshold must be positive.')
|
| 22 |
+
self.conf_thresh = conf_thresh
|
| 23 |
+
self.variance = variance
|
| 24 |
+
|
| 25 |
+
def forward(self, loc_data, conf_data, prior_data):
|
| 26 |
+
"""
|
| 27 |
+
Args:
|
| 28 |
+
loc_data: (tensor) Loc preds from loc layers
|
| 29 |
+
Shape: [batch,num_priors*4]
|
| 30 |
+
conf_data: (tensor) Shape: Conf preds from conf layers
|
| 31 |
+
Shape: [batch*num_priors,num_classes]
|
| 32 |
+
prior_data: (tensor) Prior boxes and variances from priorbox layers
|
| 33 |
+
Shape: [1,num_priors,4]
|
| 34 |
+
"""
|
| 35 |
+
num = loc_data.size(0) # batch size
|
| 36 |
+
num_priors = prior_data.size(0)
|
| 37 |
+
|
| 38 |
+
output = torch.zeros(num, self.num_classes, self.top_k, 5)
|
| 39 |
+
conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)
|
| 40 |
+
|
| 41 |
+
# Decode predictions into bboxes.
|
| 42 |
+
for i in range(num):
|
| 43 |
+
default = prior_data
|
| 44 |
+
decoded_boxes = decode(loc_data[i], default, self.variance)
|
| 45 |
+
# For each class, perform nms
|
| 46 |
+
conf_scores = conf_preds[i].clone()
|
| 47 |
+
|
| 48 |
+
for cl in range(1, self.num_classes):
|
| 49 |
+
c_mask = conf_scores[cl].gt(self.conf_thresh)
|
| 50 |
+
scores = conf_scores[cl][c_mask]
|
| 51 |
+
if scores.dim() == 0 or scores.size(0) == 0:
|
| 52 |
+
continue
|
| 53 |
+
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
|
| 54 |
+
boxes = decoded_boxes[l_mask].view(-1, 4)
|
| 55 |
+
# idx of highest scoring and non-overlapping boxes per class
|
| 56 |
+
ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
|
| 57 |
+
output[i, cl, :count] = \
|
| 58 |
+
torch.cat((scores[ids[:count]].unsqueeze(1),
|
| 59 |
+
boxes[ids[:count]]), 1)
|
| 60 |
+
flt = output.contiguous().view(num, -1, 5)
|
| 61 |
+
_, idx = flt[:, :, 0].sort(1, descending=True)
|
| 62 |
+
_, rank = idx.sort(1)
|
| 63 |
+
flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
|
| 64 |
+
return output
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
| 68 |
+
def decode(loc, priors, variances):
|
| 69 |
+
"""Decode locations from predictions using priors to undo
|
| 70 |
+
the encoding we did for offset regression at train time.
|
| 71 |
+
Args:
|
| 72 |
+
loc (tensor): location predictions for loc layers,
|
| 73 |
+
Shape: [num_priors,4]
|
| 74 |
+
priors (tensor): Prior boxes in center-offset form.
|
| 75 |
+
Shape: [num_priors,4].
|
| 76 |
+
variances: (list[float]) Variances of priorboxes
|
| 77 |
+
Return:
|
| 78 |
+
decoded bounding box predictions
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
boxes = torch.cat((
|
| 82 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
| 83 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
| 84 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
| 85 |
+
boxes[:, 2:] += boxes[:, :2]
|
| 86 |
+
# (cx,cy,w,h)->(x0,y0,x1,y1)
|
| 87 |
+
return boxes
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Original author: Francisco Massa:
|
| 91 |
+
# https://github.com/fmassa/object-detection.torch
|
| 92 |
+
# Ported to PyTorch by Max deGroot (02/01/2017)
|
| 93 |
+
def nms(boxes, scores, overlap=0.5, top_k=200):
|
| 94 |
+
"""Apply non-maximum suppression at test time to avoid detecting too many
|
| 95 |
+
overlapping bounding boxes for a given object.
|
| 96 |
+
Args:
|
| 97 |
+
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
|
| 98 |
+
scores: (tensor) The class predscores for the img, Shape:[num_priors].
|
| 99 |
+
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
|
| 100 |
+
top_k: (int) The Maximum number of box preds to consider.
|
| 101 |
+
Return:
|
| 102 |
+
The indices of the kept boxes with respect to num_priors.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
keep = scores.new(scores.size(0)).zero_().long()
|
| 106 |
+
if boxes.numel() == 0:
|
| 107 |
+
return keep
|
| 108 |
+
x1 = boxes[:, 0]
|
| 109 |
+
y1 = boxes[:, 1]
|
| 110 |
+
x2 = boxes[:, 2]
|
| 111 |
+
y2 = boxes[:, 3]
|
| 112 |
+
area = torch.mul(x2 - x1, y2 - y1)
|
| 113 |
+
v, idx = scores.sort(0) # sort in ascending order
|
| 114 |
+
# I = I[v >= 0.01]
|
| 115 |
+
idx = idx[-top_k:] # indices of the top-k largest vals
|
| 116 |
+
xx1 = boxes.new()
|
| 117 |
+
yy1 = boxes.new()
|
| 118 |
+
xx2 = boxes.new()
|
| 119 |
+
yy2 = boxes.new()
|
| 120 |
+
w = boxes.new()
|
| 121 |
+
h = boxes.new()
|
| 122 |
+
|
| 123 |
+
# keep = torch.Tensor()
|
| 124 |
+
count = 0
|
| 125 |
+
while idx.numel() > 0:
|
| 126 |
+
i = idx[-1] # index of current largest val
|
| 127 |
+
# keep.append(i)
|
| 128 |
+
keep[count] = i
|
| 129 |
+
count += 1
|
| 130 |
+
if idx.size(0) == 1:
|
| 131 |
+
break
|
| 132 |
+
idx = idx[:-1] # remove kept element from view
|
| 133 |
+
# load bboxes of next highest vals
|
| 134 |
+
torch.index_select(x1, 0, idx, out=xx1)
|
| 135 |
+
torch.index_select(y1, 0, idx, out=yy1)
|
| 136 |
+
torch.index_select(x2, 0, idx, out=xx2)
|
| 137 |
+
torch.index_select(y2, 0, idx, out=yy2)
|
| 138 |
+
# store element-wise max with next highest score
|
| 139 |
+
xx1 = torch.clamp(xx1, min=x1[i])
|
| 140 |
+
yy1 = torch.clamp(yy1, min=y1[i])
|
| 141 |
+
xx2 = torch.clamp(xx2, max=x2[i])
|
| 142 |
+
yy2 = torch.clamp(yy2, max=y2[i])
|
| 143 |
+
w.resize_as_(xx2)
|
| 144 |
+
h.resize_as_(yy2)
|
| 145 |
+
w = xx2 - xx1
|
| 146 |
+
h = yy2 - yy1
|
| 147 |
+
# check sizes of xx1 and xx2.. after each iteration
|
| 148 |
+
w = torch.clamp(w, min=0.0)
|
| 149 |
+
h = torch.clamp(h, min=0.0)
|
| 150 |
+
inter = w * h
|
| 151 |
+
# IoU = i / (area(a) + area(b) - i)
|
| 152 |
+
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
|
| 153 |
+
union = (rem_areas - inter) + area[i]
|
| 154 |
+
IoU = inter / union # store result in iou
|
| 155 |
+
# keep only elements with an IoU <= overlap
|
| 156 |
+
idx = idx[IoU.le(overlap)]
|
| 157 |
+
return keep, count
|
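The decode()/nms() pair above is easiest to sanity-check on toy inputs. A small sketch, assuming only that the package is importable as dsfacedetector (matching the file paths in this commit): two heavily overlapping boxes and one separate box should collapse to two kept detections, and decoding zero offsets just converts priors from center form to corner form.

import torch
from dsfacedetector.layers.detection import decode, nms

# three corner-form boxes: the first two overlap almost completely
boxes = torch.tensor([[0.10, 0.10, 0.30, 0.30],
                      [0.11, 0.11, 0.31, 0.31],
                      [0.60, 0.60, 0.80, 0.80]])
scores = torch.tensor([0.90, 0.80, 0.70])

keep, count = nms(boxes, scores, overlap=0.5, top_k=200)
print(keep[:count].tolist(), count)   # expected: [0, 2] and count == 2

# decode() with zero offsets converts priors from (cx, cy, w, h) to (x0, y0, x1, y1)
priors = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
print(decode(torch.zeros(1, 4), priors, (0.1, 0.2)))  # -> [[0.4, 0.4, 0.6, 0.6]]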
dsfacedetector/layers/modules.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepHeadModule(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(DeepHeadModule, self).__init__()
        self._input_channels = input_channels
        self._output_channels = output_channels
        self._mid_channels = min(self._input_channels, 256)

        self.conv1 = nn.Conv2d(self._input_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
        self.conv2 = nn.Conv2d(self._mid_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
        self.conv3 = nn.Conv2d(self._mid_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
        self.conv4 = nn.Conv2d(self._mid_channels, self._output_channels, kernel_size=1, dilation=1, stride=1,
                               padding=0)

    def forward(self, x):
        return self.conv4(
            F.relu(self.conv3(F.relu(self.conv2(F.relu(self.conv1(x), inplace=True)), inplace=True)), inplace=True))


class FEM(nn.Module):
    def __init__(self, channel_size):
        super(FEM, self).__init__()
        self.cs = channel_size
        self.cpm1 = nn.Conv2d(self.cs, 256, kernel_size=3, dilation=1, stride=1, padding=1)
        self.cpm2 = nn.Conv2d(self.cs, 256, kernel_size=3, dilation=2, stride=1, padding=2)
        self.cpm3 = nn.Conv2d(256, 128, kernel_size=3, dilation=1, stride=1, padding=1)
        self.cpm4 = nn.Conv2d(256, 128, kernel_size=3, dilation=2, stride=1, padding=2)
        self.cpm5 = nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1)

    def forward(self, x):
        x1_1 = F.relu(self.cpm1(x), inplace=True)
        x1_2 = F.relu(self.cpm2(x), inplace=True)
        x2_1 = F.relu(self.cpm3(x1_2), inplace=True)
        x2_2 = F.relu(self.cpm4(x1_2), inplace=True)
        x3_1 = F.relu(self.cpm5(x2_2), inplace=True)
        return torch.cat((x1_1, x2_1, x3_1), 1)


def upsample_product(x, y):
    '''Upsample and add two feature maps.
    Args:
        x: (Variable) top feature map to be upsampled.
        y: (Variable) lateral feature map.
    Returns:
        (Variable) added feature map.
    Note in PyTorch, when input size is odd, the upsampled feature map
    with `F.upsample(..., scale_factor=2, mode='nearest')`
    maybe not equal to the lateral feature map size.
    e.g.
    original input size: [N,_,15,15] ->
    conv2d feature map size: [N,_,8,8] ->
    upsampled feature map size: [N,_,16,16]
    So we choose bilinear upsample which supports arbitrary output sizes.
    '''
    _, _, H, W = y.size()

    # FOR ONNX CONVERSION
    # return F.interpolate(x, scale_factor=2, mode='nearest') * y
    return F.interpolate(x, size=(int(H), int(W)), mode='bilinear', align_corners=False) * y


def pa_multibox(output_channels):
    loc_layers = []
    conf_layers = []
    for k, v in enumerate(output_channels):
        if k == 0:
            loc_output = 4
            conf_output = 2
        elif k == 1:
            loc_output = 8
            conf_output = 4
        else:
            loc_output = 12
            conf_output = 6
        loc_layers += [DeepHeadModule(512, loc_output)]
        conf_layers += [DeepHeadModule(512, (2 + conf_output))]
    return (loc_layers, conf_layers)


def mio_module(each_mmbox, len_conf, your_mind_state='peasant'):
    # chunk = torch.split(each_mmbox, 1, 1) - !!!!! failed to export on PyTorch v1.0.1 (ONNX version 1.3)
    chunk = torch.chunk(each_mmbox, int(each_mmbox.shape[1]), 1)

    # some hacks for ONNX and Inference Engine export
    if your_mind_state == 'peasant':
        bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
    elif your_mind_state == 'advanced':
        bmax = torch.max(each_mmbox[:, :3], 1)[0].unsqueeze(0)
    else:  # supermind
        bmax = torch.nn.functional.max_pool3d(each_mmbox[:, :3], kernel_size=(3, 1, 1))

    cls = (torch.cat((bmax, chunk[3]), dim=1) if len_conf == 0 else torch.cat((chunk[3], bmax), dim=1))
    cls = torch.cat((cls, *list(chunk[4:])), dim=1)
    return cls
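A quick shape check for the two less obvious helpers above, upsample_product() and mio_module() (the max-in-out trick on the first three channels of the confidence head). The sketch only assumes the package layout shown in this commit; the tensor sizes are illustrative.

import torch
from dsfacedetector.layers.modules import upsample_product, mio_module

# upsample_product: x is bilinearly resized to y's spatial size, then multiplied
x = torch.randn(1, 8, 4, 4)
y = torch.randn(1, 8, 7, 7)
print(upsample_product(x, y).shape)            # torch.Size([1, 8, 7, 7])

# mio_module: for the first source map the conf head emits 2 + 2 = 4 channels
# (see pa_multibox above); the max over the first three channels plus the
# remaining channel leaves num_classes = 2 scores per location
head_out = torch.randn(1, 4, 8, 8)
print(mio_module(head_out, len_conf=0).shape)  # torch.Size([1, 2, 8, 8])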
dsfacedetector/layers/prior_box.py
ADDED
@@ -0,0 +1,133 @@
| 1 |
+
from __future__ import division
|
| 2 |
+
from math import sqrt as sqrt
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PriorBox(object):
|
| 7 |
+
"""Compute priorbox coordinates in center-offset form for each source
|
| 8 |
+
feature map.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, cfg, min_size, max_size):
|
| 12 |
+
super(PriorBox, self).__init__()
|
| 13 |
+
self.image_size = cfg['min_dim']
|
| 14 |
+
self.feature_maps = cfg['feature_maps']
|
| 15 |
+
|
| 16 |
+
self.variance = cfg['variance'] or [0.1]
|
| 17 |
+
self.min_sizes = min_size
|
| 18 |
+
self.max_sizes = max_size
|
| 19 |
+
self.steps = cfg['steps']
|
| 20 |
+
self.aspect_ratios = cfg['aspect_ratios']
|
| 21 |
+
self.clip = cfg['clip']
|
| 22 |
+
|
| 23 |
+
for v in self.variance:
|
| 24 |
+
if v <= 0:
|
| 25 |
+
raise ValueError('Variances must be greater than 0')
|
| 26 |
+
|
| 27 |
+
def forward(self):
|
| 28 |
+
|
| 29 |
+
mean = []
|
| 30 |
+
|
| 31 |
+
if len(self.min_sizes) == 5:
|
| 32 |
+
self.feature_maps = self.feature_maps[1:]
|
| 33 |
+
self.steps = self.steps[1:]
|
| 34 |
+
if len(self.min_sizes) == 4:
|
| 35 |
+
self.feature_maps = self.feature_maps[2:]
|
| 36 |
+
self.steps = self.steps[2:]
|
| 37 |
+
|
| 38 |
+
for k, f in enumerate(self.feature_maps):
|
| 39 |
+
# for i, j in product(range(f), repeat=2):
|
| 40 |
+
for i in range(f[0]):
|
| 41 |
+
for j in range(f[1]):
|
| 42 |
+
# f_k = self.image_size / self.steps[k]
|
| 43 |
+
f_k_i = self.image_size[0] / self.steps[k]
|
| 44 |
+
f_k_j = self.image_size[1] / self.steps[k]
|
| 45 |
+
# unit center x,y
|
| 46 |
+
cx = (j + 0.5) / f_k_j
|
| 47 |
+
cy = (i + 0.5) / f_k_i
|
| 48 |
+
# aspect_ratio: 1
|
| 49 |
+
# rel size: min_size
|
| 50 |
+
s_k_i = self.min_sizes[k] / self.image_size[1]
|
| 51 |
+
s_k_j = self.min_sizes[k] / self.image_size[0]
|
| 52 |
+
# swordli@tencent
|
| 53 |
+
if len(self.aspect_ratios[0]) == 0:
|
| 54 |
+
mean += [cx, cy, s_k_i, s_k_j]
|
| 55 |
+
|
| 56 |
+
# aspect_ratio: 1
|
| 57 |
+
# rel size: sqrt(s_k * s_(k+1))
|
| 58 |
+
# s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
|
| 59 |
+
if len(self.max_sizes) == len(self.min_sizes):
|
| 60 |
+
s_k_prime_i = sqrt(s_k_i * (self.max_sizes[k] / self.image_size[1]))
|
| 61 |
+
s_k_prime_j = sqrt(s_k_j * (self.max_sizes[k] / self.image_size[0]))
|
| 62 |
+
mean += [cx, cy, s_k_prime_i, s_k_prime_j]
|
| 63 |
+
# rest of aspect ratios
|
| 64 |
+
for ar in self.aspect_ratios[k]:
|
| 65 |
+
if len(self.max_sizes) == len(self.min_sizes):
|
| 66 |
+
mean += [cx, cy, s_k_prime_i / sqrt(ar), s_k_prime_j * sqrt(ar)]
|
| 67 |
+
mean += [cx, cy, s_k_i / sqrt(ar), s_k_j * sqrt(ar)]
|
| 68 |
+
|
| 69 |
+
# back to torch land
|
| 70 |
+
output = torch.Tensor(mean).view(-1, 4)
|
| 71 |
+
if self.clip:
|
| 72 |
+
output.clamp_(max=1, min=0)
|
| 73 |
+
return output
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_prior_boxes(cfg, feature_maps, image_size):
|
| 77 |
+
|
| 78 |
+
# number of priors for feature map location (either 4 or 6)
|
| 79 |
+
variance = cfg['variance'] or [0.1]
|
| 80 |
+
min_sizes = cfg['min_sizes']
|
| 81 |
+
max_sizes = cfg['max_sizes']
|
| 82 |
+
steps = cfg['steps']
|
| 83 |
+
aspect_ratios = cfg['aspect_ratios']
|
| 84 |
+
clip = cfg['clip']
|
| 85 |
+
for v in variance:
|
| 86 |
+
if v <= 0:
|
| 87 |
+
raise ValueError('Variances must be greater than 0')
|
| 88 |
+
|
| 89 |
+
mean = []
|
| 90 |
+
|
| 91 |
+
if len(min_sizes) == 5:
|
| 92 |
+
feature_maps = feature_maps[1:]
|
| 93 |
+
steps = steps[1:]
|
| 94 |
+
if len(min_sizes) == 4:
|
| 95 |
+
feature_maps = feature_maps[2:]
|
| 96 |
+
steps = steps[2:]
|
| 97 |
+
|
| 98 |
+
for k, f in enumerate(feature_maps):
|
| 99 |
+
# for i, j in product(range(f), repeat=2):
|
| 100 |
+
for i in range(f[0]):
|
| 101 |
+
for j in range(f[1]):
|
| 102 |
+
# f_k = image_size / steps[k]
|
| 103 |
+
f_k_i = image_size[0] / steps[k]
|
| 104 |
+
f_k_j = image_size[1] / steps[k]
|
| 105 |
+
# unit center x,y
|
| 106 |
+
cx = (j + 0.5) / f_k_j
|
| 107 |
+
cy = (i + 0.5) / f_k_i
|
| 108 |
+
# aspect_ratio: 1
|
| 109 |
+
# rel size: min_size
|
| 110 |
+
s_k_i = min_sizes[k] / image_size[1]
|
| 111 |
+
s_k_j = min_sizes[k] / image_size[0]
|
| 112 |
+
# swordli@tencent
|
| 113 |
+
if len(aspect_ratios[0]) == 0:
|
| 114 |
+
mean += [cx, cy, s_k_i, s_k_j]
|
| 115 |
+
|
| 116 |
+
# aspect_ratio: 1
|
| 117 |
+
# rel size: sqrt(s_k * s_(k+1))
|
| 118 |
+
# s_k_prime = sqrt(s_k * (max_sizes[k]/image_size))
|
| 119 |
+
if len(max_sizes) == len(min_sizes):
|
| 120 |
+
s_k_prime_i = sqrt(s_k_i * (max_sizes[k] / image_size[1]))
|
| 121 |
+
s_k_prime_j = sqrt(s_k_j * (max_sizes[k] / image_size[0]))
|
| 122 |
+
mean += [cx, cy, s_k_prime_i, s_k_prime_j]
|
| 123 |
+
# rest of aspect ratios
|
| 124 |
+
for ar in aspect_ratios[k]:
|
| 125 |
+
if len(max_sizes) == len(min_sizes):
|
| 126 |
+
mean += [cx, cy, s_k_prime_i / sqrt(ar), s_k_prime_j * sqrt(ar)]
|
| 127 |
+
mean += [cx, cy, s_k_i / sqrt(ar), s_k_j * sqrt(ar)]
|
| 128 |
+
|
| 129 |
+
# back to torch land
|
| 130 |
+
output = torch.Tensor(mean).view(-1, 4)
|
| 131 |
+
if clip:
|
| 132 |
+
output.clamp_(max=1, min=0)
|
| 133 |
+
return output
|
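To see what get_prior_boxes() above produces, here is a small sketch with a toy configuration; the values below are illustrative and are not the repository's data/config.py. With empty aspect-ratio lists and no max_sizes, each feature-map cell contributes exactly one prior in (cx, cy, w, h) form, normalized to the input size.

import torch
from dsfacedetector.layers.prior_box import get_prior_boxes

# toy config: two source layers, no extra aspect ratios, no max sizes
cfg = {
    'variance': [0.1, 0.2],
    'min_sizes': [16, 32],
    'max_sizes': [],
    'steps': [16, 32],
    'aspect_ratios': [[], []],
    'clip': True,
}
feature_maps = [[4, 4], [2, 2]]   # (h, w) of each source layer
image_size = (64, 64)             # (h, w) of the network input

priors = get_prior_boxes(cfg, feature_maps, image_size)
print(priors.shape)               # torch.Size([20, 4]) -> 4*4 + 2*2 cells, one prior each
print(priors[0])                  # tensor([0.1250, 0.1250, 0.2500, 0.2500])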
dsfacedetector/utils.py
ADDED
@@ -0,0 +1,101 @@
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def vis_detections(im, dets, thresh=0.5, show_text=True):
|
| 7 |
+
"""Draw detected bounding boxes."""
|
| 8 |
+
class_name = 'face'
|
| 9 |
+
inds = np.where(dets[:, -1] >= thresh)[0] if dets is not None else []
|
| 10 |
+
if len(inds) == 0:
|
| 11 |
+
return
|
| 12 |
+
im = im[:, :, (2, 1, 0)]
|
| 13 |
+
fig, ax = plt.subplots(figsize=(12, 12))
|
| 14 |
+
ax.imshow(im, aspect='equal')
|
| 15 |
+
for i in inds:
|
| 16 |
+
bbox = dets[i, :4]
|
| 17 |
+
score = dets[i, -1]
|
| 18 |
+
ax.add_patch(
|
| 19 |
+
plt.Rectangle((bbox[0], bbox[1]),
|
| 20 |
+
bbox[2] - bbox[0],
|
| 21 |
+
bbox[3] - bbox[1], fill=False,
|
| 22 |
+
edgecolor='red', linewidth=2.5)
|
| 23 |
+
)
|
| 24 |
+
if show_text:
|
| 25 |
+
ax.text(bbox[0], bbox[1] - 5,
|
| 26 |
+
'{:s} {:.3f}'.format(class_name, score),
|
| 27 |
+
bbox=dict(facecolor='blue', alpha=0.5),
|
| 28 |
+
fontsize=10, color='white')
|
| 29 |
+
ax.set_title(('{} detections with '
|
| 30 |
+
'p({} | box) >= {:.1f}').format(class_name, class_name,
|
| 31 |
+
thresh),
|
| 32 |
+
fontsize=10)
|
| 33 |
+
plt.axis('off')
|
| 34 |
+
plt.tight_layout()
|
| 35 |
+
plt.savefig('out.png')
|
| 36 |
+
plt.show()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def bbox_vote(det):
|
| 40 |
+
order = det[:, 4].ravel().argsort()[::-1]
|
| 41 |
+
det = det[order, :]
|
| 42 |
+
dets = None
|
| 43 |
+
while det.shape[0] > 0:
|
| 44 |
+
# IOU
|
| 45 |
+
area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
|
| 46 |
+
xx1 = np.maximum(det[0, 0], det[:, 0])
|
| 47 |
+
yy1 = np.maximum(det[0, 1], det[:, 1])
|
| 48 |
+
xx2 = np.minimum(det[0, 2], det[:, 2])
|
| 49 |
+
yy2 = np.minimum(det[0, 3], det[:, 3])
|
| 50 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
| 51 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
| 52 |
+
inter = w * h
|
| 53 |
+
o = inter / (area[0] + area[:] - inter)
|
| 54 |
+
# get needed merge det and delete these det
|
| 55 |
+
merge_index = np.where(o >= 0.3)[0]
|
| 56 |
+
det_accu = det[merge_index, :]
|
| 57 |
+
det = np.delete(det, merge_index, 0)
|
| 58 |
+
if merge_index.shape[0] <= 1:
|
| 59 |
+
continue
|
| 60 |
+
det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
|
| 61 |
+
max_score = np.max(det_accu[:, 4])
|
| 62 |
+
det_accu_sum = np.zeros((1, 5))
|
| 63 |
+
det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
|
| 64 |
+
det_accu_sum[:, 4] = max_score
|
| 65 |
+
try:
|
| 66 |
+
dets = np.row_stack((dets, det_accu_sum))
|
| 67 |
+
except:
|
| 68 |
+
dets = det_accu_sum
|
| 69 |
+
if dets is not None:
|
| 70 |
+
dets = dets[0:750, :]
|
| 71 |
+
return dets
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def add_borders(curr_img, target_shape=(224, 224), fill_type=0):
|
| 75 |
+
curr_h, curr_w = curr_img.shape[0:2]
|
| 76 |
+
shift_h = max(target_shape[0] - curr_h, 0)
|
| 77 |
+
shift_w = max(target_shape[1] - curr_w, 0)
|
| 78 |
+
|
| 79 |
+
image = cv2.copyMakeBorder(curr_img, shift_h // 2, (shift_h + 1) // 2, shift_w // 2, (shift_w + 1) // 2, fill_type)
|
| 80 |
+
return image, shift_h, shift_w
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def resize_image(image, target_size, resize_factor=None, is_pad=True, interpolation=3):
|
| 84 |
+
curr_image_size = image.shape[0:2]
|
| 85 |
+
|
| 86 |
+
if resize_factor is None and is_pad:
|
| 87 |
+
resize_factor = min(target_size[0] / curr_image_size[0], target_size[1] / curr_image_size[1])
|
| 88 |
+
elif resize_factor is None and not is_pad:
|
| 89 |
+
resize_factor = np.sqrt((target_size[0] * target_size[1]) / (curr_image_size[0] * curr_image_size[1]))
|
| 90 |
+
|
| 91 |
+
image = cv2.resize(image, None, None, fx=resize_factor, fy=resize_factor, interpolation=interpolation)
|
| 92 |
+
|
| 93 |
+
if is_pad:
|
| 94 |
+
image, shift_h, shift_w = add_borders(image, target_size)
|
| 95 |
+
else:
|
| 96 |
+
shift_h = shift_w = 0
|
| 97 |
+
|
| 98 |
+
scale = np.array([image.shape[1]/resize_factor, image.shape[0]/resize_factor,
|
| 99 |
+
image.shape[1]/resize_factor, image.shape[0]/resize_factor])
|
| 100 |
+
|
| 101 |
+
return image, shift_h/image.shape[0]/2, shift_w/image.shape[1]/2, scale
|
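The padding arithmetic in resize_image()/add_borders() above is easiest to follow on concrete numbers. A small sketch using a dummy frame (only numpy and OpenCV, which utils.py already imports): a 100x200 image resized into a padded 64x64 input is first scaled by 0.32, then padded 16 rows top and bottom, and the returned shifts and scale undo exactly that transformation.

import numpy as np
from dsfacedetector.utils import resize_image

frame = np.zeros((100, 200, 3), dtype=np.uint8)   # (h, w, c) dummy frame
image, shift_h_scaled, shift_w_scaled, scale = resize_image(frame, (64, 64), is_pad=True)

print(image.shape)      # (64, 64, 3): resized to 32x64, then padded 16 px top and bottom
print(shift_h_scaled)   # 0.25 -> 32 pad rows / 64 rows / 2, the fraction to subtract per side
print(shift_w_scaled)   # 0.0
print(scale)            # [200. 200. 200. 200.]: padded 64x64 divided by the 0.32 resize factor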
external_data/convert_tf_to_pt.py
ADDED
@@ -0,0 +1,174 @@
| 1 |
+
# Source: https://github.com/lukemelas/EfficientNet-PyTorch
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
def load_param(checkpoint_file, conversion_table, model_name):
|
| 8 |
+
"""
|
| 9 |
+
Load parameters according to conversion_table.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
checkpoint_file (string): pretrained checkpoint model file in tensorflow
|
| 13 |
+
conversion_table (dict): { pytorch tensor in a model : checkpoint variable name }
|
| 14 |
+
"""
|
| 15 |
+
for pyt_param, tf_param_name in conversion_table.items():
|
| 16 |
+
tf_param_name = str(model_name) + '/' + tf_param_name
|
| 17 |
+
tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)
|
| 18 |
+
if 'conv' in tf_param_name and 'kernel' in tf_param_name:
|
| 19 |
+
tf_param = np.transpose(tf_param, (3, 2, 0, 1))
|
| 20 |
+
if 'depthwise' in tf_param_name:
|
| 21 |
+
tf_param = np.transpose(tf_param, (1, 0, 2, 3))
|
| 22 |
+
elif tf_param_name.endswith('kernel'): # for weight(kernel), we should do transpose
|
| 23 |
+
tf_param = np.transpose(tf_param)
|
| 24 |
+
assert pyt_param.size() == tf_param.shape, \
|
| 25 |
+
'Dim Mismatch: %s vs %s ; %s' % (tuple(pyt_param.size()), tf_param.shape, tf_param_name)
|
| 26 |
+
pyt_param.data = torch.from_numpy(tf_param)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_efficientnet(model, checkpoint_file, model_name):
|
| 30 |
+
"""
|
| 31 |
+
Load PyTorch EfficientNet from TensorFlow checkpoint file
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# This will store the entire conversion table
|
| 35 |
+
conversion_table = {}
|
| 36 |
+
merge = lambda dict1, dict2: {**dict1, **dict2}
|
| 37 |
+
|
| 38 |
+
# All the weights not in the conv blocks
|
| 39 |
+
conversion_table_for_weights_outside_blocks = {
|
| 40 |
+
model._conv_stem.weight: 'stem/conv2d/kernel', # [3, 3, 3, 32]),
|
| 41 |
+
model._bn0.bias: 'stem/tpu_batch_normalization/beta', # [32]),
|
| 42 |
+
model._bn0.weight: 'stem/tpu_batch_normalization/gamma', # [32]),
|
| 43 |
+
model._bn0.running_mean: 'stem/tpu_batch_normalization/moving_mean', # [32]),
|
| 44 |
+
model._bn0.running_var: 'stem/tpu_batch_normalization/moving_variance', # [32]),
|
| 45 |
+
model._conv_head.weight: 'head/conv2d/kernel', # [1, 1, 320, 1280]),
|
| 46 |
+
model._bn1.bias: 'head/tpu_batch_normalization/beta', # [1280]),
|
| 47 |
+
model._bn1.weight: 'head/tpu_batch_normalization/gamma', # [1280]),
|
| 48 |
+
model._bn1.running_mean: 'head/tpu_batch_normalization/moving_mean', # [32]),
|
| 49 |
+
model._bn1.running_var: 'head/tpu_batch_normalization/moving_variance', # [32]),
|
| 50 |
+
model._fc.bias: 'head/dense/bias', # [1000]),
|
| 51 |
+
model._fc.weight: 'head/dense/kernel', # [1280, 1000]),
|
| 52 |
+
}
|
| 53 |
+
conversion_table = merge(conversion_table, conversion_table_for_weights_outside_blocks)
|
| 54 |
+
|
| 55 |
+
# The first conv block is special because it does not have _expand_conv
|
| 56 |
+
conversion_table_for_first_block = {
|
| 57 |
+
model._blocks[0]._project_conv.weight: 'blocks_0/conv2d/kernel', # 1, 1, 32, 16]),
|
| 58 |
+
model._blocks[0]._depthwise_conv.weight: 'blocks_0/depthwise_conv2d/depthwise_kernel', # [3, 3, 32, 1]),
|
| 59 |
+
model._blocks[0]._se_reduce.bias: 'blocks_0/se/conv2d/bias', # , [8]),
|
| 60 |
+
model._blocks[0]._se_reduce.weight: 'blocks_0/se/conv2d/kernel', # , [1, 1, 32, 8]),
|
| 61 |
+
model._blocks[0]._se_expand.bias: 'blocks_0/se/conv2d_1/bias', # , [32]),
|
| 62 |
+
model._blocks[0]._se_expand.weight: 'blocks_0/se/conv2d_1/kernel', # , [1, 1, 8, 32]),
|
| 63 |
+
model._blocks[0]._bn1.bias: 'blocks_0/tpu_batch_normalization/beta', # [32]),
|
| 64 |
+
model._blocks[0]._bn1.weight: 'blocks_0/tpu_batch_normalization/gamma', # [32]),
|
| 65 |
+
model._blocks[0]._bn1.running_mean: 'blocks_0/tpu_batch_normalization/moving_mean',
|
| 66 |
+
model._blocks[0]._bn1.running_var: 'blocks_0/tpu_batch_normalization/moving_variance',
|
| 67 |
+
model._blocks[0]._bn2.bias: 'blocks_0/tpu_batch_normalization_1/beta', # [16]),
|
| 68 |
+
model._blocks[0]._bn2.weight: 'blocks_0/tpu_batch_normalization_1/gamma', # [16]),
|
| 69 |
+
model._blocks[0]._bn2.running_mean: 'blocks_0/tpu_batch_normalization_1/moving_mean',
|
| 70 |
+
model._blocks[0]._bn2.running_var: 'blocks_0/tpu_batch_normalization_1/moving_variance',
|
| 71 |
+
}
|
| 72 |
+
conversion_table = merge(conversion_table, conversion_table_for_first_block)
|
| 73 |
+
|
| 74 |
+
# Conv blocks
|
| 75 |
+
for i in range(len(model._blocks)):
|
| 76 |
+
|
| 77 |
+
is_first_block = '_expand_conv.weight' not in [n for n, p in model._blocks[i].named_parameters()]
|
| 78 |
+
|
| 79 |
+
if is_first_block:
|
| 80 |
+
conversion_table_block = {
|
| 81 |
+
model._blocks[i]._project_conv.weight: 'blocks_' + str(i) + '/conv2d/kernel', # 1, 1, 32, 16]),
|
| 82 |
+
model._blocks[i]._depthwise_conv.weight: 'blocks_' + str(i) + '/depthwise_conv2d/depthwise_kernel',
|
| 83 |
+
# [3, 3, 32, 1]),
|
| 84 |
+
model._blocks[i]._se_reduce.bias: 'blocks_' + str(i) + '/se/conv2d/bias', # , [8]),
|
| 85 |
+
model._blocks[i]._se_reduce.weight: 'blocks_' + str(i) + '/se/conv2d/kernel', # , [1, 1, 32, 8]),
|
| 86 |
+
model._blocks[i]._se_expand.bias: 'blocks_' + str(i) + '/se/conv2d_1/bias', # , [32]),
|
| 87 |
+
model._blocks[i]._se_expand.weight: 'blocks_' + str(i) + '/se/conv2d_1/kernel', # , [1, 1, 8, 32]),
|
| 88 |
+
model._blocks[i]._bn1.bias: 'blocks_' + str(i) + '/tpu_batch_normalization/beta', # [32]),
|
| 89 |
+
model._blocks[i]._bn1.weight: 'blocks_' + str(i) + '/tpu_batch_normalization/gamma', # [32]),
|
| 90 |
+
model._blocks[i]._bn1.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_mean',
|
| 91 |
+
model._blocks[i]._bn1.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_variance',
|
| 92 |
+
model._blocks[i]._bn2.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_1/beta', # [16]),
|
| 93 |
+
model._blocks[i]._bn2.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_1/gamma', # [16]),
|
| 94 |
+
model._blocks[i]._bn2.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_mean',
|
| 95 |
+
model._blocks[i]._bn2.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_variance',
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
else:
|
| 99 |
+
conversion_table_block = {
|
| 100 |
+
model._blocks[i]._expand_conv.weight: 'blocks_' + str(i) + '/conv2d/kernel',
|
| 101 |
+
model._blocks[i]._project_conv.weight: 'blocks_' + str(i) + '/conv2d_1/kernel',
|
| 102 |
+
model._blocks[i]._depthwise_conv.weight: 'blocks_' + str(i) + '/depthwise_conv2d/depthwise_kernel',
|
| 103 |
+
model._blocks[i]._se_reduce.bias: 'blocks_' + str(i) + '/se/conv2d/bias',
|
| 104 |
+
model._blocks[i]._se_reduce.weight: 'blocks_' + str(i) + '/se/conv2d/kernel',
|
| 105 |
+
model._blocks[i]._se_expand.bias: 'blocks_' + str(i) + '/se/conv2d_1/bias',
|
| 106 |
+
model._blocks[i]._se_expand.weight: 'blocks_' + str(i) + '/se/conv2d_1/kernel',
|
| 107 |
+
model._blocks[i]._bn0.bias: 'blocks_' + str(i) + '/tpu_batch_normalization/beta',
|
| 108 |
+
model._blocks[i]._bn0.weight: 'blocks_' + str(i) + '/tpu_batch_normalization/gamma',
|
| 109 |
+
model._blocks[i]._bn0.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_mean',
|
| 110 |
+
model._blocks[i]._bn0.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_variance',
|
| 111 |
+
model._blocks[i]._bn1.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_1/beta',
|
| 112 |
+
model._blocks[i]._bn1.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_1/gamma',
|
| 113 |
+
model._blocks[i]._bn1.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_mean',
|
| 114 |
+
model._blocks[i]._bn1.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_variance',
|
| 115 |
+
model._blocks[i]._bn2.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_2/beta',
|
| 116 |
+
model._blocks[i]._bn2.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_2/gamma',
|
| 117 |
+
model._blocks[i]._bn2.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_2/moving_mean',
|
| 118 |
+
model._blocks[i]._bn2.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_2/moving_variance',
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
conversion_table = merge(conversion_table, conversion_table_block)
|
| 122 |
+
|
| 123 |
+
# Load TensorFlow parameters into PyTorch model
|
| 124 |
+
load_param(checkpoint_file, conversion_table, model_name)
|
| 125 |
+
return conversion_table
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img= '../../example/img.jpg'):
|
| 129 |
+
""" Loads and saves a TensorFlow model. """
|
| 130 |
+
image_files = [example_img]
|
| 131 |
+
eval_ckpt_driver = eval_ckpt_main.EvalCkptDriver(model_name)
|
| 132 |
+
with tf.Graph().as_default(), tf.Session() as sess:
|
| 133 |
+
images, labels = eval_ckpt_driver.build_dataset(image_files, [0] * len(image_files), False)
|
| 134 |
+
probs = eval_ckpt_driver.build_model(images, is_training=False)
|
| 135 |
+
sess.run(tf.global_variables_initializer())
|
| 136 |
+
print(model_ckpt)
|
| 137 |
+
eval_ckpt_driver.restore_model(sess, model_ckpt)
|
| 138 |
+
tf.train.Saver().save(sess, 'tmp/model.ckpt')
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == '__main__':
|
| 142 |
+
|
| 143 |
+
import sys
|
| 144 |
+
import argparse
|
| 145 |
+
|
| 146 |
+
sys.path.append('original_tf')
|
| 147 |
+
import eval_ckpt_main
|
| 148 |
+
|
| 149 |
+
from efficientnet_pytorch import EfficientNet
|
| 150 |
+
|
| 151 |
+
parser = argparse.ArgumentParser(
|
| 152 |
+
description='Convert TF model to PyTorch model and save for easier future loading')
|
| 153 |
+
parser.add_argument('--model_name', type=str, default='efficientnet-b0',
|
| 154 |
+
help='efficientnet-b{N}, where N is an integer 0 <= N <= 8')
|
| 155 |
+
parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/',
|
| 156 |
+
help='checkpoint file path')
|
| 157 |
+
parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth',
|
| 158 |
+
help='output PyTorch model file name')
|
| 159 |
+
args = parser.parse_args()
|
| 160 |
+
|
| 161 |
+
# Build model
|
| 162 |
+
model = EfficientNet.from_name(args.model_name)
|
| 163 |
+
|
| 164 |
+
# Load and save temporary TensorFlow file due to TF nuances
|
| 165 |
+
print(args.tf_checkpoint)
|
| 166 |
+
load_and_save_temporary_tensorflow_model(args.model_name, args.tf_checkpoint)
|
| 167 |
+
|
| 168 |
+
# Load weights
|
| 169 |
+
load_efficientnet(model, 'tmp/model.ckpt', model_name=args.model_name)
|
| 170 |
+
print('Loaded TF checkpoint weights')
|
| 171 |
+
|
| 172 |
+
# Save PyTorch file
|
| 173 |
+
torch.save(model.state_dict(), args.output_file)
|
| 174 |
+
print('Saved model to', args.output_file)
|
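The only non-trivial part of load_param() above is the layout change for convolution kernels: TensorFlow stores them as HWIO (height, width, in, out) while PyTorch's Conv2d.weight is OIHW, and depthwise kernels get one extra swap. A tiny numpy sketch of the same transposes, with illustrative shapes:

import numpy as np

tf_kernel = np.zeros((3, 3, 32, 64))          # TF conv kernel: H, W, in, out
pt_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))
print(pt_kernel.shape)                        # (64, 32, 3, 3): out, in, H, W as PyTorch expects

tf_dw = np.zeros((3, 3, 32, 1))               # TF depthwise kernel: H, W, channels, multiplier
pt_dw = np.transpose(tf_dw, (3, 2, 0, 1))     # first transpose -> (1, 32, 3, 3)
pt_dw = np.transpose(pt_dw, (1, 0, 2, 3))     # extra depthwise swap in load_param -> (32, 1, 3, 3)
print(pt_dw.shape)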
external_data/original_tf/__init__.py
ADDED
File without changes
external_data/original_tf/efficientnet_builder.py
ADDED
@@ -0,0 +1,329 @@
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Model Builder for EfficientNet."""
|
| 16 |
+
|
| 17 |
+
from __future__ import absolute_import
|
| 18 |
+
from __future__ import division
|
| 19 |
+
from __future__ import print_function
|
| 20 |
+
|
| 21 |
+
import functools
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
from absl import logging
|
| 25 |
+
import numpy as np
|
| 26 |
+
import six
|
| 27 |
+
import tensorflow.compat.v1 as tf
|
| 28 |
+
|
| 29 |
+
import efficientnet_model
|
| 30 |
+
import utils
|
| 31 |
+
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
| 32 |
+
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def efficientnet_params(model_name):
|
| 36 |
+
"""Get efficientnet params based on model name."""
|
| 37 |
+
params_dict = {
|
| 38 |
+
# (width_coefficient, depth_coefficient, resolution, dropout_rate)
|
| 39 |
+
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
| 40 |
+
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
| 41 |
+
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
| 42 |
+
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
| 43 |
+
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
| 44 |
+
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
| 45 |
+
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
| 46 |
+
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
| 47 |
+
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
|
| 48 |
+
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
|
| 49 |
+
}
|
| 50 |
+
return params_dict[model_name]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class BlockDecoder(object):
|
| 54 |
+
"""Block Decoder for readability."""
|
| 55 |
+
|
| 56 |
+
def _decode_block_string(self, block_string):
|
| 57 |
+
"""Gets a block through a string notation of arguments."""
|
| 58 |
+
if six.PY2:
|
| 59 |
+
assert isinstance(block_string, (str, unicode))
|
| 60 |
+
else:
|
| 61 |
+
assert isinstance(block_string, str)
|
| 62 |
+
ops = block_string.split('_')
|
| 63 |
+
options = {}
|
| 64 |
+
for op in ops:
|
| 65 |
+
splits = re.split(r'(\d.*)', op)
|
| 66 |
+
if len(splits) >= 2:
|
| 67 |
+
key, value = splits[:2]
|
| 68 |
+
options[key] = value
|
| 69 |
+
|
| 70 |
+
if 's' not in options or len(options['s']) != 2:
|
| 71 |
+
raise ValueError('Strides options should be a pair of integers.')
|
| 72 |
+
|
| 73 |
+
return efficientnet_model.BlockArgs(
|
| 74 |
+
kernel_size=int(options['k']),
|
| 75 |
+
num_repeat=int(options['r']),
|
| 76 |
+
input_filters=int(options['i']),
|
| 77 |
+
output_filters=int(options['o']),
|
| 78 |
+
expand_ratio=int(options['e']),
|
| 79 |
+
id_skip=('noskip' not in block_string),
|
| 80 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
| 81 |
+
strides=[int(options['s'][0]),
|
| 82 |
+
int(options['s'][1])],
|
| 83 |
+
conv_type=int(options['c']) if 'c' in options else 0,
|
| 84 |
+
fused_conv=int(options['f']) if 'f' in options else 0,
|
| 85 |
+
super_pixel=int(options['p']) if 'p' in options else 0,
|
| 86 |
+
condconv=('cc' in block_string))
|
| 87 |
+
|
| 88 |
+
def _encode_block_string(self, block):
|
| 89 |
+
"""Encodes a block to a string."""
|
| 90 |
+
args = [
|
| 91 |
+
'r%d' % block.num_repeat,
|
| 92 |
+
'k%d' % block.kernel_size,
|
| 93 |
+
's%d%d' % (block.strides[0], block.strides[1]),
|
| 94 |
+
'e%s' % block.expand_ratio,
|
| 95 |
+
'i%d' % block.input_filters,
|
| 96 |
+
'o%d' % block.output_filters,
|
| 97 |
+
'c%d' % block.conv_type,
|
| 98 |
+
'f%d' % block.fused_conv,
|
| 99 |
+
'p%d' % block.super_pixel,
|
| 100 |
+
]
|
| 101 |
+
if block.se_ratio > 0 and block.se_ratio <= 1:
|
| 102 |
+
args.append('se%s' % block.se_ratio)
|
| 103 |
+
if block.id_skip is False: # pylint: disable=g-bool-id-comparison
|
| 104 |
+
args.append('noskip')
|
| 105 |
+
if block.condconv:
|
| 106 |
+
args.append('cc')
|
| 107 |
+
return '_'.join(args)
|
| 108 |
+
|
| 109 |
+
def decode(self, string_list):
|
| 110 |
+
"""Decodes a list of string notations to specify blocks inside the network.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
string_list: a list of strings, each string is a notation of block.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
A list of namedtuples to represent blocks arguments.
|
| 117 |
+
"""
|
| 118 |
+
assert isinstance(string_list, list)
|
| 119 |
+
blocks_args = []
|
| 120 |
+
for block_string in string_list:
|
| 121 |
+
blocks_args.append(self._decode_block_string(block_string))
|
| 122 |
+
return blocks_args
|
| 123 |
+
|
| 124 |
+
def encode(self, blocks_args):
|
| 125 |
+
"""Encodes a list of Blocks to a list of strings.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
blocks_args: A list of namedtuples to represent blocks arguments.
|
| 129 |
+
Returns:
|
| 130 |
+
a list of strings, each string is a notation of block.
|
| 131 |
+
"""
|
| 132 |
+
block_strings = []
|
| 133 |
+
for block in blocks_args:
|
| 134 |
+
block_strings.append(self._encode_block_string(block))
|
| 135 |
+
return block_strings
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def swish(features, use_native=True, use_hard=False):
|
| 139 |
+
"""Computes the Swish activation function.
|
| 140 |
+
|
| 141 |
+
We provide three alternatives:
|
| 142 |
+
- Native tf.nn.swish, use less memory during training than composable swish.
|
| 143 |
+
- Quantization friendly hard swish.
|
| 144 |
+
- A composable swish, equivalent to tf.nn.swish, but more general for
|
| 145 |
+
finetuning and TF-Hub.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
features: A `Tensor` representing preactivation values.
|
| 149 |
+
use_native: Whether to use the native swish from tf.nn that uses a custom
|
| 150 |
+
gradient to reduce memory usage, or to use customized swish that uses
|
| 151 |
+
default TensorFlow gradient computation.
|
| 152 |
+
use_hard: Whether to use quantization-friendly hard swish.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
The activation value.
|
| 156 |
+
"""
|
| 157 |
+
if use_native and use_hard:
|
| 158 |
+
raise ValueError('Cannot specify both use_native and use_hard.')
|
| 159 |
+
|
| 160 |
+
if use_native:
|
| 161 |
+
return tf.nn.swish(features)
|
| 162 |
+
|
| 163 |
+
if use_hard:
|
| 164 |
+
return features * tf.nn.relu6(features + np.float32(3)) * (1. / 6.)
|
| 165 |
+
|
| 166 |
+
features = tf.convert_to_tensor(features, name='features')
|
| 167 |
+
return features * tf.nn.sigmoid(features)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
_DEFAULT_BLOCKS_ARGS = [
|
| 171 |
+
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
|
| 172 |
+
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
|
| 173 |
+
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
|
| 174 |
+
'r1_k3_s11_e6_i192_o320_se0.25',
|
| 175 |
+
]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def efficientnet(width_coefficient=None,
|
| 179 |
+
depth_coefficient=None,
|
| 180 |
+
dropout_rate=0.2,
|
| 181 |
+
survival_prob=0.8):
|
| 182 |
+
"""Creates a efficientnet model."""
|
| 183 |
+
global_params = efficientnet_model.GlobalParams(
|
| 184 |
+
blocks_args=_DEFAULT_BLOCKS_ARGS,
|
| 185 |
+
batch_norm_momentum=0.99,
|
| 186 |
+
batch_norm_epsilon=1e-3,
|
| 187 |
+
dropout_rate=dropout_rate,
|
| 188 |
+
survival_prob=survival_prob,
|
| 189 |
+
data_format='channels_last',
|
| 190 |
+
num_classes=1000,
|
| 191 |
+
width_coefficient=width_coefficient,
|
| 192 |
+
depth_coefficient=depth_coefficient,
|
| 193 |
+
depth_divisor=8,
|
| 194 |
+
min_depth=None,
|
| 195 |
+
relu_fn=tf.nn.swish,
|
| 196 |
+
# The default is TPU-specific batch norm.
|
| 197 |
+
# The alternative is tf.layers.BatchNormalization.
|
| 198 |
+
batch_norm=utils.TpuBatchNormalization, # TPU-specific requirement.
|
| 199 |
+
use_se=True,
|
| 200 |
+
clip_projection_output=False)
|
| 201 |
+
return global_params
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_model_params(model_name, override_params):
|
| 205 |
+
"""Get the block args and global params for a given model."""
|
| 206 |
+
if model_name.startswith('efficientnet'):
|
| 207 |
+
width_coefficient, depth_coefficient, _, dropout_rate = (
|
| 208 |
+
efficientnet_params(model_name))
|
| 209 |
+
global_params = efficientnet(
|
| 210 |
+
width_coefficient, depth_coefficient, dropout_rate)
|
| 211 |
+
else:
|
| 212 |
+
raise NotImplementedError('model name is not pre-defined: %s' % model_name)
|
| 213 |
+
|
| 214 |
+
if override_params:
|
| 215 |
+
# ValueError will be raised here if override_params has fields not included
|
| 216 |
+
# in global_params.
|
| 217 |
+
global_params = global_params._replace(**override_params)
|
| 218 |
+
|
| 219 |
+
decoder = BlockDecoder()
|
| 220 |
+
blocks_args = decoder.decode(global_params.blocks_args)
|
| 221 |
+
|
| 222 |
+
logging.info('global_params= %s', global_params)
|
| 223 |
+
return blocks_args, global_params
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def build_model(images,
|
| 227 |
+
model_name,
|
| 228 |
+
training,
|
| 229 |
+
override_params=None,
|
| 230 |
+
model_dir=None,
|
| 231 |
+
fine_tuning=False,
|
| 232 |
+
features_only=False,
|
| 233 |
+
pooled_features_only=False):
|
| 234 |
+
"""A helper functiion to creates a model and returns predicted logits.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
images: input images tensor.
|
| 238 |
+
model_name: string, the predefined model name.
|
| 239 |
+
training: boolean, whether the model is constructed for training.
|
| 240 |
+
override_params: A dictionary of params for overriding. Fields must exist in
|
| 241 |
+
efficientnet_model.GlobalParams.
|
| 242 |
+
model_dir: string, optional model dir for saving configs.
|
| 243 |
+
fine_tuning: boolean, whether the model is used for finetuning.
|
| 244 |
+
features_only: build the base feature network only (excluding final
|
| 245 |
+
1x1 conv layer, global pooling, dropout and fc head).
|
| 246 |
+
pooled_features_only: build the base network for features extraction (after
|
| 247 |
+
1x1 conv layer and global pooling, but before dropout and fc head).
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
logits: the logits tensor of classes.
|
| 251 |
+
endpoints: the endpoints for each layer.
|
| 252 |
+
|
| 253 |
+
Raises:
|
| 254 |
+
When model_name specified an undefined model, raises NotImplementedError.
|
| 255 |
+
When override_params has invalid fields, raises ValueError.
|
| 256 |
+
"""
|
| 257 |
+
assert isinstance(images, tf.Tensor)
|
| 258 |
+
assert not (features_only and pooled_features_only)
|
| 259 |
+
|
| 260 |
+
# For backward compatibility.
|
| 261 |
+
if override_params and override_params.get('drop_connect_rate', None):
|
| 262 |
+
override_params['survival_prob'] = 1 - override_params['drop_connect_rate']
|
| 263 |
+
|
| 264 |
+
if not training or fine_tuning:
|
| 265 |
+
if not override_params:
|
| 266 |
+
override_params = {}
|
| 267 |
+
override_params['batch_norm'] = utils.BatchNormalization
|
| 268 |
+
if fine_tuning:
|
| 269 |
+
override_params['relu_fn'] = functools.partial(swish, use_native=False)
|
| 270 |
+
blocks_args, global_params = get_model_params(model_name, override_params)
|
| 271 |
+
|
| 272 |
+
if model_dir:
|
| 273 |
+
param_file = os.path.join(model_dir, 'model_params.txt')
|
| 274 |
+
if not tf.gfile.Exists(param_file):
|
| 275 |
+
if not tf.gfile.Exists(model_dir):
|
| 276 |
+
tf.gfile.MakeDirs(model_dir)
|
| 277 |
+
with tf.gfile.GFile(param_file, 'w') as f:
|
| 278 |
+
logging.info('writing to %s', param_file)
|
| 279 |
+
f.write('model_name= %s\n\n' % model_name)
|
| 280 |
+
f.write('global_params= %s\n\n' % str(global_params))
|
| 281 |
+
f.write('blocks_args= %s\n\n' % str(blocks_args))
|
| 282 |
+
|
| 283 |
+
with tf.variable_scope(model_name):
|
| 284 |
+
model = efficientnet_model.Model(blocks_args, global_params)
|
| 285 |
+
outputs = model(
|
| 286 |
+
images,
|
| 287 |
+
training=training,
|
| 288 |
+
features_only=features_only,
|
| 289 |
+
pooled_features_only=pooled_features_only)
|
| 290 |
+
if features_only:
|
| 291 |
+
outputs = tf.identity(outputs, 'features')
|
| 292 |
+
elif pooled_features_only:
|
| 293 |
+
outputs = tf.identity(outputs, 'pooled_features')
|
| 294 |
+
else:
|
| 295 |
+
outputs = tf.identity(outputs, 'logits')
|
| 296 |
+
return outputs, model.endpoints
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def build_model_base(images, model_name, training, override_params=None):
|
| 300 |
+
"""A helper functiion to create a base model and return global_pool.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
images: input images tensor.
|
| 304 |
+
model_name: string, the predefined model name.
|
| 305 |
+
training: boolean, whether the model is constructed for training.
|
| 306 |
+
override_params: A dictionary of params for overriding. Fields must exist in
|
| 307 |
+
efficientnet_model.GlobalParams.
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
features: global pool features.
|
| 311 |
+
endpoints: the endpoints for each layer.
|
| 312 |
+
|
| 313 |
+
Raises:
|
| 314 |
+
When model_name specified an undefined model, raises NotImplementedError.
|
| 315 |
+
When override_params has invalid fields, raises ValueError.
|
| 316 |
+
"""
|
| 317 |
+
assert isinstance(images, tf.Tensor)
|
| 318 |
+
# For backward compatibility.
|
| 319 |
+
if override_params and override_params.get('drop_connect_rate', None):
|
| 320 |
+
override_params['survival_prob'] = 1 - override_params['drop_connect_rate']
|
| 321 |
+
|
| 322 |
+
blocks_args, global_params = get_model_params(model_name, override_params)
|
| 323 |
+
|
| 324 |
+
with tf.variable_scope(model_name):
|
| 325 |
+
model = efficientnet_model.Model(blocks_args, global_params)
|
| 326 |
+
features = model(images, training=training, features_only=True)
|
| 327 |
+
|
| 328 |
+
features = tf.identity(features, 'features')
|
| 329 |
+
return features, model.endpoints
|
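The block strings consumed by BlockDecoder above encode one stage per string (repeat count, kernel size, strides, expand ratio, input/output filters, SE ratio). A standalone sketch of the same key/value split used in _decode_block_string, applied to the first default block string; it re-uses only the regex shown above, not the repository module itself.

import re

block_string = 'r1_k3_s11_e1_i32_o16_se0.25'
options = {}
for op in block_string.split('_'):
    splits = re.split(r'(\d.*)', op)
    if len(splits) >= 2:
        key, value = splits[:2]
        options[key] = value

print(options)
# {'r': '1', 'k': '3', 's': '11', 'e': '1', 'i': '32', 'o': '16', 'se': '0.25'}
# i.e. repeat once, 3x3 kernel, stride (1, 1), expand ratio 1, 32 -> 16 filters, SE ratio 0.25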
external_data/original_tf/efficientnet_model.py
ADDED
@@ -0,0 +1,713 @@
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Contains definitions for EfficientNet model.
|
| 16 |
+
|
| 17 |
+
[1] Mingxing Tan, Quoc V. Le
|
| 18 |
+
EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
|
| 19 |
+
ICML'19, https://arxiv.org/abs/1905.11946
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import absolute_import
|
| 23 |
+
from __future__ import division
|
| 24 |
+
from __future__ import print_function
|
| 25 |
+
|
| 26 |
+
import collections
|
| 27 |
+
import functools
|
| 28 |
+
import math
|
| 29 |
+
|
| 30 |
+
from absl import logging
|
| 31 |
+
import numpy as np
|
| 32 |
+
import six
|
| 33 |
+
from six.moves import xrange
|
| 34 |
+
import tensorflow.compat.v1 as tf
|
| 35 |
+
|
| 36 |
+
import utils
|
| 37 |
+
# from condconv import condconv_layers
|
| 38 |
+
|
| 39 |
+
GlobalParams = collections.namedtuple('GlobalParams', [
|
| 40 |
+
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
|
| 41 |
+
'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
|
| 42 |
+
'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se',
|
| 43 |
+
'local_pooling', 'condconv_num_experts', 'clip_projection_output',
|
| 44 |
+
'blocks_args'
|
| 45 |
+
])
|
| 46 |
+
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
|
| 47 |
+
|
| 48 |
+
BlockArgs = collections.namedtuple('BlockArgs', [
|
| 49 |
+
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
|
| 50 |
+
'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 'fused_conv',
|
| 51 |
+
'super_pixel', 'condconv'
|
| 52 |
+
])
|
| 53 |
+
# defaults will be a public argument for namedtuple in Python 3.7
|
| 54 |
+
# https://docs.python.org/3/library/collections.html#collections.namedtuple
|
| 55 |
+
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def conv_kernel_initializer(shape, dtype=None, partition_info=None):
|
| 59 |
+
"""Initialization for convolutional kernels.
|
| 60 |
+
|
| 61 |
+
The main difference with tf.variance_scaling_initializer is that
|
| 62 |
+
tf.variance_scaling_initializer uses a truncated normal with an uncorrected
|
| 63 |
+
standard deviation, whereas here we use a normal distribution. Similarly,
|
| 64 |
+
tf.initializers.variance_scaling uses a truncated normal with
|
| 65 |
+
a corrected standard deviation.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
shape: shape of variable
|
| 69 |
+
dtype: dtype of variable
|
| 70 |
+
partition_info: unused
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
an initialization for the variable
|
| 74 |
+
"""
|
| 75 |
+
del partition_info
|
| 76 |
+
kernel_height, kernel_width, _, out_filters = shape
|
| 77 |
+
fan_out = int(kernel_height * kernel_width * out_filters)
|
| 78 |
+
return tf.random_normal(
|
| 79 |
+
shape, mean=0.0, stddev=np.sqrt(2.0 / fan_out), dtype=dtype)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def dense_kernel_initializer(shape, dtype=None, partition_info=None):
|
| 83 |
+
"""Initialization for dense kernels.
|
| 84 |
+
|
| 85 |
+
This initialization is equal to
|
| 86 |
+
tf.variance_scaling_initializer(scale=1.0/3.0, mode='fan_out',
|
| 87 |
+
distribution='uniform').
|
| 88 |
+
It is written out explicitly here for clarity.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
shape: shape of variable
|
| 92 |
+
dtype: dtype of variable
|
| 93 |
+
partition_info: unused
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
an initialization for the variable
|
| 97 |
+
"""
|
| 98 |
+
del partition_info
|
| 99 |
+
init_range = 1.0 / np.sqrt(shape[1])
|
| 100 |
+
return tf.random_uniform(shape, -init_range, init_range, dtype=dtype)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def superpixel_kernel_initializer(shape, dtype='float32', partition_info=None):
|
| 104 |
+
"""Initializes superpixel kernels.
|
| 105 |
+
|
| 106 |
+
This is inspired by space-to-depth transformation that is mathematically
|
| 107 |
+
equivalent before and after the transformation. But we do the space-to-depth
|
| 108 |
+
via a convolution. Moreover, we make the layer trainable instead of direct
|
| 109 |
+
transform, we can initialization it this way so that the model can learn not
|
| 110 |
+
to do anything but keep it mathematically equivalent, when improving
|
| 111 |
+
performance.
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
shape: shape of variable
|
| 116 |
+
dtype: dtype of variable
|
| 117 |
+
partition_info: unused
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
an initialization for the variable
|
| 121 |
+
"""
|
| 122 |
+
del partition_info
|
| 123 |
+
# use input depth to make superpixel kernel.
|
| 124 |
+
depth = shape[-2]
|
| 125 |
+
filters = np.zeros([2, 2, depth, 4 * depth], dtype=dtype)
|
| 126 |
+
i = np.arange(2)
|
| 127 |
+
j = np.arange(2)
|
| 128 |
+
k = np.arange(depth)
|
| 129 |
+
mesh = np.array(np.meshgrid(i, j, k)).T.reshape(-1, 3).T
|
| 130 |
+
filters[
|
| 131 |
+
mesh[0],
|
| 132 |
+
mesh[1],
|
| 133 |
+
mesh[2],
|
| 134 |
+
4 * mesh[2] + 2 * mesh[0] + mesh[1]] = 1
|
| 135 |
+
return filters
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def round_filters(filters, global_params):
|
| 139 |
+
"""Round number of filters based on depth multiplier."""
|
| 140 |
+
orig_f = filters
|
| 141 |
+
multiplier = global_params.width_coefficient
|
| 142 |
+
divisor = global_params.depth_divisor
|
| 143 |
+
min_depth = global_params.min_depth
|
| 144 |
+
if not multiplier:
|
| 145 |
+
return filters
|
| 146 |
+
|
| 147 |
+
filters *= multiplier
|
| 148 |
+
min_depth = min_depth or divisor
|
| 149 |
+
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
|
| 150 |
+
# Make sure that round down does not go down by more than 10%.
|
| 151 |
+
if new_filters < 0.9 * filters:
|
| 152 |
+
new_filters += divisor
|
| 153 |
+
logging.info('round_filter input=%s output=%s', orig_f, new_filters)
|
| 154 |
+
return int(new_filters)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def round_repeats(repeats, global_params):
|
| 158 |
+
"""Round number of filters based on depth multiplier."""
|
| 159 |
+
multiplier = global_params.depth_coefficient
|
| 160 |
+
if not multiplier:
|
| 161 |
+
return repeats
|
| 162 |
+
return int(math.ceil(multiplier * repeats))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class MBConvBlock(tf.keras.layers.Layer):
|
| 166 |
+
"""A class of MBConv: Mobile Inverted Residual Bottleneck.
|
| 167 |
+
|
| 168 |
+
Attributes:
|
| 169 |
+
endpoints: dict. A list of internal tensors.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(self, block_args, global_params):
|
| 173 |
+
"""Initializes a MBConv block.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
block_args: BlockArgs, arguments to create a Block.
|
| 177 |
+
global_params: GlobalParams, a set of global parameters.
|
| 178 |
+
"""
|
| 179 |
+
super(MBConvBlock, self).__init__()
|
| 180 |
+
self._block_args = block_args
|
| 181 |
+
self._batch_norm_momentum = global_params.batch_norm_momentum
|
| 182 |
+
self._batch_norm_epsilon = global_params.batch_norm_epsilon
|
| 183 |
+
self._batch_norm = global_params.batch_norm
|
| 184 |
+
self._condconv_num_experts = global_params.condconv_num_experts
|
| 185 |
+
self._data_format = global_params.data_format
|
| 186 |
+
if self._data_format == 'channels_first':
|
| 187 |
+
self._channel_axis = 1
|
| 188 |
+
self._spatial_dims = [2, 3]
|
| 189 |
+
else:
|
| 190 |
+
self._channel_axis = -1
|
| 191 |
+
self._spatial_dims = [1, 2]
|
| 192 |
+
|
| 193 |
+
self._relu_fn = global_params.relu_fn or tf.nn.swish
|
| 194 |
+
self._has_se = (
|
| 195 |
+
global_params.use_se and self._block_args.se_ratio is not None and
|
| 196 |
+
0 < self._block_args.se_ratio <= 1)
|
| 197 |
+
|
| 198 |
+
self._clip_projection_output = global_params.clip_projection_output
|
| 199 |
+
|
| 200 |
+
self.endpoints = None
|
| 201 |
+
|
| 202 |
+
self.conv_cls = tf.layers.Conv2D
|
| 203 |
+
self.depthwise_conv_cls = utils.DepthwiseConv2D
|
| 204 |
+
if self._block_args.condconv:
|
| 205 |
+
self.conv_cls = functools.partial(
|
| 206 |
+
condconv_layers.CondConv2D, num_experts=self._condconv_num_experts)
|
| 207 |
+
self.depthwise_conv_cls = functools.partial(
|
| 208 |
+
condconv_layers.DepthwiseCondConv2D,
|
| 209 |
+
num_experts=self._condconv_num_experts)
|
| 210 |
+
|
| 211 |
+
# Builds the block accordings to arguments.
|
| 212 |
+
self._build()
|
| 213 |
+
|
| 214 |
+
def block_args(self):
|
| 215 |
+
return self._block_args
|
| 216 |
+
|
| 217 |
+
def _build(self):
|
| 218 |
+
"""Builds block according to the arguments."""
|
| 219 |
+
if self._block_args.super_pixel == 1:
|
| 220 |
+
self._superpixel = tf.layers.Conv2D(
|
| 221 |
+
self._block_args.input_filters,
|
| 222 |
+
kernel_size=[2, 2],
|
| 223 |
+
strides=[2, 2],
|
| 224 |
+
kernel_initializer=conv_kernel_initializer,
|
| 225 |
+
padding='same',
|
| 226 |
+
data_format=self._data_format,
|
| 227 |
+
use_bias=False)
|
| 228 |
+
self._bnsp = self._batch_norm(
|
| 229 |
+
axis=self._channel_axis,
|
| 230 |
+
momentum=self._batch_norm_momentum,
|
| 231 |
+
epsilon=self._batch_norm_epsilon)
|
| 232 |
+
|
| 233 |
+
if self._block_args.condconv:
|
| 234 |
+
# Add the example-dependent routing function
|
| 235 |
+
self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
|
| 236 |
+
data_format=self._data_format)
|
| 237 |
+
self._routing_fn = tf.layers.Dense(
|
| 238 |
+
self._condconv_num_experts, activation=tf.nn.sigmoid)
|
| 239 |
+
|
| 240 |
+
filters = self._block_args.input_filters * self._block_args.expand_ratio
|
| 241 |
+
kernel_size = self._block_args.kernel_size
|
| 242 |
+
|
| 243 |
+
# Fused expansion phase. Called if using fused convolutions.
|
| 244 |
+
self._fused_conv = self.conv_cls(
|
| 245 |
+
filters=filters,
|
| 246 |
+
kernel_size=[kernel_size, kernel_size],
|
| 247 |
+
strides=self._block_args.strides,
|
| 248 |
+
kernel_initializer=conv_kernel_initializer,
|
| 249 |
+
padding='same',
|
| 250 |
+
data_format=self._data_format,
|
| 251 |
+
use_bias=False)
|
| 252 |
+
|
| 253 |
+
# Expansion phase. Called if not using fused convolutions and expansion
|
| 254 |
+
# phase is necessary.
|
| 255 |
+
self._expand_conv = self.conv_cls(
|
| 256 |
+
filters=filters,
|
| 257 |
+
kernel_size=[1, 1],
|
| 258 |
+
strides=[1, 1],
|
| 259 |
+
kernel_initializer=conv_kernel_initializer,
|
| 260 |
+
padding='same',
|
| 261 |
+
data_format=self._data_format,
|
| 262 |
+
use_bias=False)
|
| 263 |
+
self._bn0 = self._batch_norm(
|
| 264 |
+
axis=self._channel_axis,
|
| 265 |
+
momentum=self._batch_norm_momentum,
|
| 266 |
+
epsilon=self._batch_norm_epsilon)
|
| 267 |
+
|
| 268 |
+
# Depth-wise convolution phase. Called if not using fused convolutions.
|
| 269 |
+
self._depthwise_conv = self.depthwise_conv_cls(
|
| 270 |
+
kernel_size=[kernel_size, kernel_size],
|
| 271 |
+
strides=self._block_args.strides,
|
| 272 |
+
depthwise_initializer=conv_kernel_initializer,
|
| 273 |
+
padding='same',
|
| 274 |
+
data_format=self._data_format,
|
| 275 |
+
use_bias=False)
|
| 276 |
+
|
| 277 |
+
self._bn1 = self._batch_norm(
|
| 278 |
+
axis=self._channel_axis,
|
| 279 |
+
momentum=self._batch_norm_momentum,
|
| 280 |
+
epsilon=self._batch_norm_epsilon)
|
| 281 |
+
|
| 282 |
+
if self._has_se:
|
| 283 |
+
num_reduced_filters = max(
|
| 284 |
+
1, int(self._block_args.input_filters * self._block_args.se_ratio))
|
| 285 |
+
# Squeeze and Excitation layer.
|
| 286 |
+
self._se_reduce = tf.layers.Conv2D(
|
| 287 |
+
num_reduced_filters,
|
| 288 |
+
kernel_size=[1, 1],
|
| 289 |
+
strides=[1, 1],
|
| 290 |
+
kernel_initializer=conv_kernel_initializer,
|
| 291 |
+
padding='same',
|
| 292 |
+
data_format=self._data_format,
|
| 293 |
+
use_bias=True)
|
| 294 |
+
self._se_expand = tf.layers.Conv2D(
|
| 295 |
+
filters,
|
| 296 |
+
kernel_size=[1, 1],
|
| 297 |
+
strides=[1, 1],
|
| 298 |
+
kernel_initializer=conv_kernel_initializer,
|
| 299 |
+
padding='same',
|
| 300 |
+
data_format=self._data_format,
|
| 301 |
+
use_bias=True)
|
| 302 |
+
|
| 303 |
+
# Output phase.
|
| 304 |
+
filters = self._block_args.output_filters
|
| 305 |
+
self._project_conv = self.conv_cls(
|
| 306 |
+
filters=filters,
|
| 307 |
+
kernel_size=[1, 1],
|
| 308 |
+
strides=[1, 1],
|
| 309 |
+
kernel_initializer=conv_kernel_initializer,
|
| 310 |
+
padding='same',
|
| 311 |
+
data_format=self._data_format,
|
| 312 |
+
use_bias=False)
|
| 313 |
+
self._bn2 = self._batch_norm(
|
| 314 |
+
axis=self._channel_axis,
|
| 315 |
+
momentum=self._batch_norm_momentum,
|
| 316 |
+
epsilon=self._batch_norm_epsilon)
|
| 317 |
+
|
| 318 |
+
def _call_se(self, input_tensor):
|
| 319 |
+
"""Call Squeeze and Excitation layer.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
input_tensor: Tensor, a single input tensor for Squeeze/Excitation layer.
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
A output tensor, which should have the same shape as input.
|
| 326 |
+
"""
|
| 327 |
+
se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True)
|
| 328 |
+
se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor)))
|
| 329 |
+
logging.info('Built Squeeze and Excitation with tensor shape: %s',
|
| 330 |
+
(se_tensor.shape))
|
| 331 |
+
return tf.sigmoid(se_tensor) * input_tensor
|
| 332 |
+
|
| 333 |
+
def call(self, inputs, training=True, survival_prob=None):
|
| 334 |
+
"""Implementation of call().
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
inputs: the inputs tensor.
|
| 338 |
+
training: boolean, whether the model is constructed for training.
|
| 339 |
+
survival_prob: float, between 0 to 1, drop connect rate.
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
A output tensor.
|
| 343 |
+
"""
|
| 344 |
+
logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
|
| 345 |
+
logging.info('Block input depth: %s output depth: %s',
|
| 346 |
+
self._block_args.input_filters,
|
| 347 |
+
self._block_args.output_filters)
|
| 348 |
+
|
| 349 |
+
x = inputs
|
| 350 |
+
|
| 351 |
+
fused_conv_fn = self._fused_conv
|
| 352 |
+
expand_conv_fn = self._expand_conv
|
| 353 |
+
depthwise_conv_fn = self._depthwise_conv
|
| 354 |
+
project_conv_fn = self._project_conv
|
| 355 |
+
|
| 356 |
+
if self._block_args.condconv:
|
| 357 |
+
pooled_inputs = self._avg_pooling(inputs)
|
| 358 |
+
routing_weights = self._routing_fn(pooled_inputs)
|
| 359 |
+
# Capture routing weights as additional input to CondConv layers
|
| 360 |
+
fused_conv_fn = functools.partial(
|
| 361 |
+
self._fused_conv, routing_weights=routing_weights)
|
| 362 |
+
expand_conv_fn = functools.partial(
|
| 363 |
+
self._expand_conv, routing_weights=routing_weights)
|
| 364 |
+
depthwise_conv_fn = functools.partial(
|
| 365 |
+
self._depthwise_conv, routing_weights=routing_weights)
|
| 366 |
+
project_conv_fn = functools.partial(
|
| 367 |
+
self._project_conv, routing_weights=routing_weights)
|
| 368 |
+
|
| 369 |
+
# creates conv 2x2 kernel
|
| 370 |
+
if self._block_args.super_pixel == 1:
|
| 371 |
+
with tf.variable_scope('super_pixel'):
|
| 372 |
+
x = self._relu_fn(
|
| 373 |
+
self._bnsp(self._superpixel(x), training=training))
|
| 374 |
+
logging.info(
|
| 375 |
+
'Block start with SuperPixel: %s shape: %s', x.name, x.shape)
|
| 376 |
+
|
| 377 |
+
if self._block_args.fused_conv:
|
| 378 |
+
# If use fused mbconv, skip expansion and use regular conv.
|
| 379 |
+
x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training))
|
| 380 |
+
logging.info('Conv2D: %s shape: %s', x.name, x.shape)
|
| 381 |
+
else:
|
| 382 |
+
# Otherwise, first apply expansion and then apply depthwise conv.
|
| 383 |
+
if self._block_args.expand_ratio != 1:
|
| 384 |
+
x = self._relu_fn(self._bn0(expand_conv_fn(x), training=training))
|
| 385 |
+
logging.info('Expand: %s shape: %s', x.name, x.shape)
|
| 386 |
+
|
| 387 |
+
x = self._relu_fn(self._bn1(depthwise_conv_fn(x), training=training))
|
| 388 |
+
logging.info('DWConv: %s shape: %s', x.name, x.shape)
|
| 389 |
+
|
| 390 |
+
if self._has_se:
|
| 391 |
+
with tf.variable_scope('se'):
|
| 392 |
+
x = self._call_se(x)
|
| 393 |
+
|
| 394 |
+
self.endpoints = {'expansion_output': x}
|
| 395 |
+
|
| 396 |
+
x = self._bn2(project_conv_fn(x), training=training)
|
| 397 |
+
# Add identity so that quantization-aware training can insert quantization
|
| 398 |
+
# ops correctly.
|
| 399 |
+
x = tf.identity(x)
|
| 400 |
+
if self._clip_projection_output:
|
| 401 |
+
x = tf.clip_by_value(x, -6, 6)
|
| 402 |
+
if self._block_args.id_skip:
|
| 403 |
+
if all(
|
| 404 |
+
s == 1 for s in self._block_args.strides
|
| 405 |
+
) and self._block_args.input_filters == self._block_args.output_filters:
|
| 406 |
+
# Apply only if skip connection presents.
|
| 407 |
+
if survival_prob:
|
| 408 |
+
x = utils.drop_connect(x, training, survival_prob)
|
| 409 |
+
x = tf.add(x, inputs)
|
| 410 |
+
logging.info('Project: %s shape: %s', x.name, x.shape)
|
| 411 |
+
return x
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class MBConvBlockWithoutDepthwise(MBConvBlock):
|
| 415 |
+
"""MBConv-like block without depthwise convolution and squeeze-and-excite."""
|
| 416 |
+
|
| 417 |
+
def _build(self):
|
| 418 |
+
"""Builds block according to the arguments."""
|
| 419 |
+
filters = self._block_args.input_filters * self._block_args.expand_ratio
|
| 420 |
+
if self._block_args.expand_ratio != 1:
|
| 421 |
+
# Expansion phase:
|
| 422 |
+
self._expand_conv = tf.layers.Conv2D(
|
| 423 |
+
filters,
|
| 424 |
+
kernel_size=[3, 3],
|
| 425 |
+
strides=[1, 1],
|
| 426 |
+
kernel_initializer=conv_kernel_initializer,
|
| 427 |
+
padding='same',
|
| 428 |
+
use_bias=False)
|
| 429 |
+
self._bn0 = self._batch_norm(
|
| 430 |
+
axis=self._channel_axis,
|
| 431 |
+
momentum=self._batch_norm_momentum,
|
| 432 |
+
epsilon=self._batch_norm_epsilon)
|
| 433 |
+
|
| 434 |
+
# Output phase:
|
| 435 |
+
filters = self._block_args.output_filters
|
| 436 |
+
self._project_conv = tf.layers.Conv2D(
|
| 437 |
+
filters,
|
| 438 |
+
kernel_size=[1, 1],
|
| 439 |
+
strides=self._block_args.strides,
|
| 440 |
+
kernel_initializer=conv_kernel_initializer,
|
| 441 |
+
padding='same',
|
| 442 |
+
use_bias=False)
|
| 443 |
+
self._bn1 = self._batch_norm(
|
| 444 |
+
axis=self._channel_axis,
|
| 445 |
+
momentum=self._batch_norm_momentum,
|
| 446 |
+
epsilon=self._batch_norm_epsilon)
|
| 447 |
+
|
| 448 |
+
def call(self, inputs, training=True, survival_prob=None):
|
| 449 |
+
"""Implementation of call().
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
inputs: the inputs tensor.
|
| 453 |
+
training: boolean, whether the model is constructed for training.
|
| 454 |
+
survival_prob: float, between 0 to 1, drop connect rate.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
A output tensor.
|
| 458 |
+
"""
|
| 459 |
+
logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
|
| 460 |
+
if self._block_args.expand_ratio != 1:
|
| 461 |
+
x = self._relu_fn(self._bn0(self._expand_conv(inputs), training=training))
|
| 462 |
+
else:
|
| 463 |
+
x = inputs
|
| 464 |
+
logging.info('Expand: %s shape: %s', x.name, x.shape)
|
| 465 |
+
|
| 466 |
+
self.endpoints = {'expansion_output': x}
|
| 467 |
+
|
| 468 |
+
x = self._bn1(self._project_conv(x), training=training)
|
| 469 |
+
# Add identity so that quantization-aware training can insert quantization
|
| 470 |
+
# ops correctly.
|
| 471 |
+
x = tf.identity(x)
|
| 472 |
+
if self._clip_projection_output:
|
| 473 |
+
x = tf.clip_by_value(x, -6, 6)
|
| 474 |
+
|
| 475 |
+
if self._block_args.id_skip:
|
| 476 |
+
if all(
|
| 477 |
+
s == 1 for s in self._block_args.strides
|
| 478 |
+
) and self._block_args.input_filters == self._block_args.output_filters:
|
| 479 |
+
# Apply only if skip connection presents.
|
| 480 |
+
if survival_prob:
|
| 481 |
+
x = utils.drop_connect(x, training, survival_prob)
|
| 482 |
+
x = tf.add(x, inputs)
|
| 483 |
+
logging.info('Project: %s shape: %s', x.name, x.shape)
|
| 484 |
+
return x
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class Model(tf.keras.Model):
|
| 488 |
+
"""A class implements tf.keras.Model for MNAS-like model.
|
| 489 |
+
|
| 490 |
+
Reference: https://arxiv.org/abs/1807.11626
|
| 491 |
+
"""
|
| 492 |
+
|
| 493 |
+
def __init__(self, blocks_args=None, global_params=None):
|
| 494 |
+
"""Initializes an `Model` instance.
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
blocks_args: A list of BlockArgs to construct block modules.
|
| 498 |
+
global_params: GlobalParams, a set of global parameters.
|
| 499 |
+
|
| 500 |
+
Raises:
|
| 501 |
+
ValueError: when blocks_args is not specified as a list.
|
| 502 |
+
"""
|
| 503 |
+
super(Model, self).__init__()
|
| 504 |
+
if not isinstance(blocks_args, list):
|
| 505 |
+
raise ValueError('blocks_args should be a list.')
|
| 506 |
+
self._global_params = global_params
|
| 507 |
+
self._blocks_args = blocks_args
|
| 508 |
+
self._relu_fn = global_params.relu_fn or tf.nn.swish
|
| 509 |
+
self._batch_norm = global_params.batch_norm
|
| 510 |
+
|
| 511 |
+
self.endpoints = None
|
| 512 |
+
|
| 513 |
+
self._build()
|
| 514 |
+
|
| 515 |
+
def _get_conv_block(self, conv_type):
|
| 516 |
+
conv_block_map = {0: MBConvBlock, 1: MBConvBlockWithoutDepthwise}
|
| 517 |
+
return conv_block_map[conv_type]
|
| 518 |
+
|
| 519 |
+
def _build(self):
|
| 520 |
+
"""Builds a model."""
|
| 521 |
+
self._blocks = []
|
| 522 |
+
batch_norm_momentum = self._global_params.batch_norm_momentum
|
| 523 |
+
batch_norm_epsilon = self._global_params.batch_norm_epsilon
|
| 524 |
+
if self._global_params.data_format == 'channels_first':
|
| 525 |
+
channel_axis = 1
|
| 526 |
+
self._spatial_dims = [2, 3]
|
| 527 |
+
else:
|
| 528 |
+
channel_axis = -1
|
| 529 |
+
self._spatial_dims = [1, 2]
|
| 530 |
+
|
| 531 |
+
# Stem part.
|
| 532 |
+
self._conv_stem = tf.layers.Conv2D(
|
| 533 |
+
filters=round_filters(32, self._global_params),
|
| 534 |
+
kernel_size=[3, 3],
|
| 535 |
+
strides=[2, 2],
|
| 536 |
+
kernel_initializer=conv_kernel_initializer,
|
| 537 |
+
padding='same',
|
| 538 |
+
data_format=self._global_params.data_format,
|
| 539 |
+
use_bias=False)
|
| 540 |
+
self._bn0 = self._batch_norm(
|
| 541 |
+
axis=channel_axis,
|
| 542 |
+
momentum=batch_norm_momentum,
|
| 543 |
+
epsilon=batch_norm_epsilon)
|
| 544 |
+
|
| 545 |
+
# Builds blocks.
|
| 546 |
+
for block_args in self._blocks_args:
|
| 547 |
+
assert block_args.num_repeat > 0
|
| 548 |
+
assert block_args.super_pixel in [0, 1, 2]
|
| 549 |
+
# Update block input and output filters based on depth multiplier.
|
| 550 |
+
input_filters = round_filters(block_args.input_filters,
|
| 551 |
+
self._global_params)
|
| 552 |
+
output_filters = round_filters(block_args.output_filters,
|
| 553 |
+
self._global_params)
|
| 554 |
+
kernel_size = block_args.kernel_size
|
| 555 |
+
block_args = block_args._replace(
|
| 556 |
+
input_filters=input_filters,
|
| 557 |
+
output_filters=output_filters,
|
| 558 |
+
num_repeat=round_repeats(block_args.num_repeat, self._global_params))
|
| 559 |
+
|
| 560 |
+
# The first block needs to take care of stride and filter size increase.
|
| 561 |
+
conv_block = self._get_conv_block(block_args.conv_type)
|
| 562 |
+
if not block_args.super_pixel: # no super_pixel at all
|
| 563 |
+
self._blocks.append(conv_block(block_args, self._global_params))
|
| 564 |
+
else:
|
| 565 |
+
# if superpixel, adjust filters, kernels, and strides.
|
| 566 |
+
depth_factor = int(4 / block_args.strides[0] / block_args.strides[1])
|
| 567 |
+
block_args = block_args._replace(
|
| 568 |
+
input_filters=block_args.input_filters * depth_factor,
|
| 569 |
+
output_filters=block_args.output_filters * depth_factor,
|
| 570 |
+
kernel_size=((block_args.kernel_size + 1) // 2 if depth_factor > 1
|
| 571 |
+
else block_args.kernel_size))
|
| 572 |
+
# if the first block has stride-2 and super_pixel trandformation
|
| 573 |
+
if (block_args.strides[0] == 2 and block_args.strides[1] == 2):
|
| 574 |
+
block_args = block_args._replace(strides=[1, 1])
|
| 575 |
+
self._blocks.append(conv_block(block_args, self._global_params))
|
| 576 |
+
block_args = block_args._replace( # sp stops at stride-2
|
| 577 |
+
super_pixel=0,
|
| 578 |
+
input_filters=input_filters,
|
| 579 |
+
output_filters=output_filters,
|
| 580 |
+
kernel_size=kernel_size)
|
| 581 |
+
elif block_args.super_pixel == 1:
|
| 582 |
+
self._blocks.append(conv_block(block_args, self._global_params))
|
| 583 |
+
block_args = block_args._replace(super_pixel=2)
|
| 584 |
+
else:
|
| 585 |
+
self._blocks.append(conv_block(block_args, self._global_params))
|
| 586 |
+
if block_args.num_repeat > 1: # rest of blocks with the same block_arg
|
| 587 |
+
# pylint: disable=protected-access
|
| 588 |
+
block_args = block_args._replace(
|
| 589 |
+
input_filters=block_args.output_filters, strides=[1, 1])
|
| 590 |
+
# pylint: enable=protected-access
|
| 591 |
+
for _ in xrange(block_args.num_repeat - 1):
|
| 592 |
+
self._blocks.append(conv_block(block_args, self._global_params))
|
| 593 |
+
|
| 594 |
+
# Head part.
|
| 595 |
+
self._conv_head = tf.layers.Conv2D(
|
| 596 |
+
filters=round_filters(1280, self._global_params),
|
| 597 |
+
kernel_size=[1, 1],
|
| 598 |
+
strides=[1, 1],
|
| 599 |
+
kernel_initializer=conv_kernel_initializer,
|
| 600 |
+
padding='same',
|
| 601 |
+
use_bias=False)
|
| 602 |
+
self._bn1 = self._batch_norm(
|
| 603 |
+
axis=channel_axis,
|
| 604 |
+
momentum=batch_norm_momentum,
|
| 605 |
+
epsilon=batch_norm_epsilon)
|
| 606 |
+
|
| 607 |
+
self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
|
| 608 |
+
data_format=self._global_params.data_format)
|
| 609 |
+
if self._global_params.num_classes:
|
| 610 |
+
self._fc = tf.layers.Dense(
|
| 611 |
+
self._global_params.num_classes,
|
| 612 |
+
kernel_initializer=dense_kernel_initializer)
|
| 613 |
+
else:
|
| 614 |
+
self._fc = None
|
| 615 |
+
|
| 616 |
+
if self._global_params.dropout_rate > 0:
|
| 617 |
+
self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate)
|
| 618 |
+
else:
|
| 619 |
+
self._dropout = None
|
| 620 |
+
|
| 621 |
+
def call(self,
|
| 622 |
+
inputs,
|
| 623 |
+
training=True,
|
| 624 |
+
features_only=None,
|
| 625 |
+
pooled_features_only=False):
|
| 626 |
+
"""Implementation of call().
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
inputs: input tensors.
|
| 630 |
+
training: boolean, whether the model is constructed for training.
|
| 631 |
+
features_only: build the base feature network only.
|
| 632 |
+
pooled_features_only: build the base network for features extraction
|
| 633 |
+
(after 1x1 conv layer and global pooling, but before dropout and fc
|
| 634 |
+
head).
|
| 635 |
+
|
| 636 |
+
Returns:
|
| 637 |
+
output tensors.
|
| 638 |
+
"""
|
| 639 |
+
outputs = None
|
| 640 |
+
self.endpoints = {}
|
| 641 |
+
reduction_idx = 0
|
| 642 |
+
# Calls Stem layers
|
| 643 |
+
with tf.variable_scope('stem'):
|
| 644 |
+
outputs = self._relu_fn(
|
| 645 |
+
self._bn0(self._conv_stem(inputs), training=training))
|
| 646 |
+
logging.info('Built stem layers with output shape: %s', outputs.shape)
|
| 647 |
+
self.endpoints['stem'] = outputs
|
| 648 |
+
|
| 649 |
+
# Calls blocks.
|
| 650 |
+
for idx, block in enumerate(self._blocks):
|
| 651 |
+
is_reduction = False # reduction flag for blocks after the stem layer
|
| 652 |
+
# If the first block has super-pixel (space-to-depth) layer, then stem is
|
| 653 |
+
# the first reduction point.
|
| 654 |
+
if (block.block_args().super_pixel == 1 and idx == 0):
|
| 655 |
+
reduction_idx += 1
|
| 656 |
+
self.endpoints['reduction_%s' % reduction_idx] = outputs
|
| 657 |
+
|
| 658 |
+
elif ((idx == len(self._blocks) - 1) or
|
| 659 |
+
self._blocks[idx + 1].block_args().strides[0] > 1):
|
| 660 |
+
is_reduction = True
|
| 661 |
+
reduction_idx += 1
|
| 662 |
+
|
| 663 |
+
with tf.variable_scope('blocks_%s' % idx):
|
| 664 |
+
survival_prob = self._global_params.survival_prob
|
| 665 |
+
if survival_prob:
|
| 666 |
+
drop_rate = 1.0 - survival_prob
|
| 667 |
+
survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks)
|
| 668 |
+
logging.info('block_%s survival_prob: %s', idx, survival_prob)
|
| 669 |
+
outputs = block.call(
|
| 670 |
+
outputs, training=training, survival_prob=survival_prob)
|
| 671 |
+
self.endpoints['block_%s' % idx] = outputs
|
| 672 |
+
if is_reduction:
|
| 673 |
+
self.endpoints['reduction_%s' % reduction_idx] = outputs
|
| 674 |
+
if block.endpoints:
|
| 675 |
+
for k, v in six.iteritems(block.endpoints):
|
| 676 |
+
self.endpoints['block_%s/%s' % (idx, k)] = v
|
| 677 |
+
if is_reduction:
|
| 678 |
+
self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
|
| 679 |
+
self.endpoints['features'] = outputs
|
| 680 |
+
|
| 681 |
+
if not features_only:
|
| 682 |
+
# Calls final layers and returns logits.
|
| 683 |
+
with tf.variable_scope('head'):
|
| 684 |
+
outputs = self._relu_fn(
|
| 685 |
+
self._bn1(self._conv_head(outputs), training=training))
|
| 686 |
+
self.endpoints['head_1x1'] = outputs
|
| 687 |
+
|
| 688 |
+
if self._global_params.local_pooling:
|
| 689 |
+
shape = outputs.get_shape().as_list()
|
| 690 |
+
kernel_size = [
|
| 691 |
+
1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1]
|
| 692 |
+
outputs = tf.nn.avg_pool(
|
| 693 |
+
outputs, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID')
|
| 694 |
+
self.endpoints['pooled_features'] = outputs
|
| 695 |
+
if not pooled_features_only:
|
| 696 |
+
if self._dropout:
|
| 697 |
+
outputs = self._dropout(outputs, training=training)
|
| 698 |
+
self.endpoints['global_pool'] = outputs
|
| 699 |
+
if self._fc:
|
| 700 |
+
outputs = tf.squeeze(outputs, self._spatial_dims)
|
| 701 |
+
outputs = self._fc(outputs)
|
| 702 |
+
self.endpoints['head'] = outputs
|
| 703 |
+
else:
|
| 704 |
+
outputs = self._avg_pooling(outputs)
|
| 705 |
+
self.endpoints['pooled_features'] = outputs
|
| 706 |
+
if not pooled_features_only:
|
| 707 |
+
if self._dropout:
|
| 708 |
+
outputs = self._dropout(outputs, training=training)
|
| 709 |
+
self.endpoints['global_pool'] = outputs
|
| 710 |
+
if self._fc:
|
| 711 |
+
outputs = self._fc(outputs)
|
| 712 |
+
self.endpoints['head'] = outputs
|
| 713 |
+
return outputs
|
external_data/original_tf/eval_ckpt_main.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Eval checkpoint driver.
|
| 16 |
+
|
| 17 |
+
This is an example evaluation script for users to understand the EfficientNet
|
| 18 |
+
model checkpoints on CPU. To serve EfficientNet, please consider to export a
|
| 19 |
+
`SavedModel` from checkpoints and use tf-serving to serve.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import absolute_import
|
| 23 |
+
from __future__ import division
|
| 24 |
+
from __future__ import print_function
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import sys
|
| 28 |
+
from absl import app
|
| 29 |
+
from absl import flags
|
| 30 |
+
import numpy as np
|
| 31 |
+
import tensorflow as tf
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
import efficientnet_builder
|
| 35 |
+
import preprocessing
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
flags.DEFINE_string('model_name', 'efficientnet-b0', 'Model name to eval.')
|
| 39 |
+
flags.DEFINE_string('runmode', 'examples', 'Running mode: examples or imagenet')
|
| 40 |
+
flags.DEFINE_string('imagenet_eval_glob', None,
|
| 41 |
+
'Imagenet eval image glob, '
|
| 42 |
+
'such as /imagenet/ILSVRC2012*.JPEG')
|
| 43 |
+
flags.DEFINE_string('imagenet_eval_label', None,
|
| 44 |
+
'Imagenet eval label file path, '
|
| 45 |
+
'such as /imagenet/ILSVRC2012_validation_ground_truth.txt')
|
| 46 |
+
flags.DEFINE_string('ckpt_dir', '/tmp/ckpt/', 'Checkpoint folders')
|
| 47 |
+
flags.DEFINE_string('example_img', '/tmp/panda.jpg',
|
| 48 |
+
'Filepath for a single example image.')
|
| 49 |
+
flags.DEFINE_string('labels_map_file', '/tmp/labels_map.txt',
|
| 50 |
+
'Labels map from label id to its meaning.')
|
| 51 |
+
flags.DEFINE_integer('num_images', 5000,
|
| 52 |
+
'Number of images to eval. Use -1 to eval all images.')
|
| 53 |
+
FLAGS = flags.FLAGS
|
| 54 |
+
|
| 55 |
+
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
| 56 |
+
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class EvalCkptDriver(object):
|
| 60 |
+
"""A driver for running eval inference.
|
| 61 |
+
|
| 62 |
+
Attributes:
|
| 63 |
+
model_name: str. Model name to eval.
|
| 64 |
+
batch_size: int. Eval batch size.
|
| 65 |
+
num_classes: int. Number of classes, default to 1000 for ImageNet.
|
| 66 |
+
image_size: int. Input image size, determined by model name.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, model_name='efficientnet-b0', batch_size=1):
|
| 70 |
+
"""Initialize internal variables."""
|
| 71 |
+
self.model_name = model_name
|
| 72 |
+
self.batch_size = batch_size
|
| 73 |
+
self.num_classes = 1000
|
| 74 |
+
# Model Scaling parameters
|
| 75 |
+
_, _, self.image_size, _ = efficientnet_builder.efficientnet_params(
|
| 76 |
+
model_name)
|
| 77 |
+
|
| 78 |
+
def restore_model(self, sess, ckpt_dir):
|
| 79 |
+
"""Restore variables from checkpoint dir."""
|
| 80 |
+
checkpoint = tf.train.latest_checkpoint(ckpt_dir)
|
| 81 |
+
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
|
| 82 |
+
ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
|
| 83 |
+
for v in tf.global_variables():
|
| 84 |
+
if 'moving_mean' in v.name or 'moving_variance' in v.name:
|
| 85 |
+
ema_vars.append(v)
|
| 86 |
+
ema_vars = list(set(ema_vars))
|
| 87 |
+
var_dict = ema.variables_to_restore(ema_vars)
|
| 88 |
+
saver = tf.train.Saver(var_dict, max_to_keep=1)
|
| 89 |
+
saver.restore(sess, checkpoint)
|
| 90 |
+
|
| 91 |
+
def build_model(self, features, is_training):
|
| 92 |
+
"""Build model with input features."""
|
| 93 |
+
features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
|
| 94 |
+
features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)
|
| 95 |
+
logits, _ = efficientnet_builder.build_model(
|
| 96 |
+
features, self.model_name, is_training)
|
| 97 |
+
probs = tf.nn.softmax(logits)
|
| 98 |
+
probs = tf.squeeze(probs)
|
| 99 |
+
return probs
|
| 100 |
+
|
| 101 |
+
def build_dataset(self, filenames, labels, is_training):
|
| 102 |
+
"""Build input dataset."""
|
| 103 |
+
filenames = tf.constant(filenames)
|
| 104 |
+
labels = tf.constant(labels)
|
| 105 |
+
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
|
| 106 |
+
|
| 107 |
+
def _parse_function(filename, label):
|
| 108 |
+
image_string = tf.read_file(filename)
|
| 109 |
+
image_decoded = preprocessing.preprocess_image(
|
| 110 |
+
image_string, is_training, self.image_size)
|
| 111 |
+
image = tf.cast(image_decoded, tf.float32)
|
| 112 |
+
return image, label
|
| 113 |
+
|
| 114 |
+
dataset = dataset.map(_parse_function)
|
| 115 |
+
dataset = dataset.batch(self.batch_size)
|
| 116 |
+
|
| 117 |
+
iterator = dataset.make_one_shot_iterator()
|
| 118 |
+
images, labels = iterator.get_next()
|
| 119 |
+
return images, labels
|
| 120 |
+
|
| 121 |
+
def run_inference(self, ckpt_dir, image_files, labels):
|
| 122 |
+
"""Build and run inference on the target images and labels."""
|
| 123 |
+
with tf.Graph().as_default(), tf.Session() as sess:
|
| 124 |
+
images, labels = self.build_dataset(image_files, labels, False)
|
| 125 |
+
probs = self.build_model(images, is_training=False)
|
| 126 |
+
|
| 127 |
+
sess.run(tf.global_variables_initializer())
|
| 128 |
+
self.restore_model(sess, ckpt_dir)
|
| 129 |
+
|
| 130 |
+
prediction_idx = []
|
| 131 |
+
prediction_prob = []
|
| 132 |
+
for _ in range(len(image_files) // self.batch_size):
|
| 133 |
+
out_probs = sess.run(probs)
|
| 134 |
+
idx = np.argsort(out_probs)[::-1]
|
| 135 |
+
prediction_idx.append(idx[:5])
|
| 136 |
+
prediction_prob.append([out_probs[pid] for pid in idx[:5]])
|
| 137 |
+
|
| 138 |
+
# Return the top 5 predictions (idx and prob) for each image.
|
| 139 |
+
return prediction_idx, prediction_prob
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def eval_example_images(model_name, ckpt_dir, image_files, labels_map_file):
|
| 143 |
+
"""Eval a list of example images.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
model_name: str. The name of model to eval.
|
| 147 |
+
ckpt_dir: str. Checkpoint directory path.
|
| 148 |
+
image_files: List[str]. A list of image file paths.
|
| 149 |
+
labels_map_file: str. The labels map file path.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
A tuple (pred_idx, and pred_prob), where pred_idx is the top 5 prediction
|
| 153 |
+
index and pred_prob is the top 5 prediction probability.
|
| 154 |
+
"""
|
| 155 |
+
eval_ckpt_driver = EvalCkptDriver(model_name)
|
| 156 |
+
classes = json.loads(tf.gfile.Open(labels_map_file).read())
|
| 157 |
+
pred_idx, pred_prob = eval_ckpt_driver.run_inference(
|
| 158 |
+
ckpt_dir, image_files, [0] * len(image_files))
|
| 159 |
+
for i in range(len(image_files)):
|
| 160 |
+
print('predicted class for image {}: '.format(image_files[i]))
|
| 161 |
+
for j, idx in enumerate(pred_idx[i]):
|
| 162 |
+
print(' -> top_{} ({:4.2f}%): {} '.format(
|
| 163 |
+
j, pred_prob[i][j] * 100, classes[str(idx)]))
|
| 164 |
+
return pred_idx, pred_prob
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def eval_imagenet(model_name,
|
| 168 |
+
ckpt_dir,
|
| 169 |
+
imagenet_eval_glob,
|
| 170 |
+
imagenet_eval_label,
|
| 171 |
+
num_images):
|
| 172 |
+
"""Eval ImageNet images and report top1/top5 accuracy.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
model_name: str. The name of model to eval.
|
| 176 |
+
ckpt_dir: str. Checkpoint directory path.
|
| 177 |
+
imagenet_eval_glob: str. File path glob for all eval images.
|
| 178 |
+
imagenet_eval_label: str. File path for eval label.
|
| 179 |
+
num_images: int. Number of images to eval: -1 means eval the whole dataset.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
A tuple (top1, top5) for top1 and top5 accuracy.
|
| 183 |
+
"""
|
| 184 |
+
eval_ckpt_driver = EvalCkptDriver(model_name)
|
| 185 |
+
imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
|
| 186 |
+
imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
|
| 187 |
+
if num_images < 0:
|
| 188 |
+
num_images = len(imagenet_filenames)
|
| 189 |
+
image_files = imagenet_filenames[:num_images]
|
| 190 |
+
labels = imagenet_val_labels[:num_images]
|
| 191 |
+
|
| 192 |
+
pred_idx, _ = eval_ckpt_driver.run_inference(ckpt_dir, image_files, labels)
|
| 193 |
+
top1_cnt, top5_cnt = 0.0, 0.0
|
| 194 |
+
for i, label in enumerate(labels):
|
| 195 |
+
top1_cnt += label in pred_idx[i][:1]
|
| 196 |
+
top5_cnt += label in pred_idx[i][:5]
|
| 197 |
+
if i % 100 == 0:
|
| 198 |
+
print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(
|
| 199 |
+
i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
|
| 200 |
+
sys.stdout.flush()
|
| 201 |
+
top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
|
| 202 |
+
print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5))
|
| 203 |
+
return top1, top5
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def main(unused_argv):
|
| 207 |
+
tf.logging.set_verbosity(tf.logging.ERROR)
|
| 208 |
+
if FLAGS.runmode == 'examples':
|
| 209 |
+
# Run inference for an example image.
|
| 210 |
+
eval_example_images(FLAGS.model_name, FLAGS.ckpt_dir, [FLAGS.example_img],
|
| 211 |
+
FLAGS.labels_map_file)
|
| 212 |
+
elif FLAGS.runmode == 'imagenet':
|
| 213 |
+
# Run inference for imagenet.
|
| 214 |
+
eval_imagenet(FLAGS.model_name, FLAGS.ckpt_dir, FLAGS.imagenet_eval_glob,
|
| 215 |
+
FLAGS.imagenet_eval_label, FLAGS.num_images)
|
| 216 |
+
else:
|
| 217 |
+
print('must specify runmode: examples or imagenet')
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == '__main__':
|
| 221 |
+
app.run(main)
|
external_data/original_tf/preprocessing.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""ImageNet preprocessing."""
|
| 16 |
+
from __future__ import absolute_import
|
| 17 |
+
from __future__ import division
|
| 18 |
+
from __future__ import print_function
|
| 19 |
+
|
| 20 |
+
from absl import logging
|
| 21 |
+
|
| 22 |
+
import tensorflow.compat.v1 as tf
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
IMAGE_SIZE = 224
|
| 26 |
+
CROP_PADDING = 32
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def distorted_bounding_box_crop(image_bytes,
|
| 30 |
+
bbox,
|
| 31 |
+
min_object_covered=0.1,
|
| 32 |
+
aspect_ratio_range=(0.75, 1.33),
|
| 33 |
+
area_range=(0.05, 1.0),
|
| 34 |
+
max_attempts=100,
|
| 35 |
+
scope=None):
|
| 36 |
+
"""Generates cropped_image using one of the bboxes randomly distorted.
|
| 37 |
+
|
| 38 |
+
See `tf.image.sample_distorted_bounding_box` for more documentation.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
image_bytes: `Tensor` of binary image data.
|
| 42 |
+
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
|
| 43 |
+
where each coordinate is [0, 1) and the coordinates are arranged
|
| 44 |
+
as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
|
| 45 |
+
image.
|
| 46 |
+
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
|
| 47 |
+
area of the image must contain at least this fraction of any bounding
|
| 48 |
+
box supplied.
|
| 49 |
+
aspect_ratio_range: An optional list of `float`s. The cropped area of the
|
| 50 |
+
image must have an aspect ratio = width / height within this range.
|
| 51 |
+
area_range: An optional list of `float`s. The cropped area of the image
|
| 52 |
+
must contain a fraction of the supplied image within in this range.
|
| 53 |
+
max_attempts: An optional `int`. Number of attempts at generating a cropped
|
| 54 |
+
region of the image of the specified constraints. After `max_attempts`
|
| 55 |
+
failures, return the entire image.
|
| 56 |
+
scope: Optional `str` for name scope.
|
| 57 |
+
Returns:
|
| 58 |
+
cropped image `Tensor`
|
| 59 |
+
"""
|
| 60 |
+
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
|
| 61 |
+
shape = tf.image.extract_jpeg_shape(image_bytes)
|
| 62 |
+
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
| 63 |
+
shape,
|
| 64 |
+
bounding_boxes=bbox,
|
| 65 |
+
min_object_covered=min_object_covered,
|
| 66 |
+
aspect_ratio_range=aspect_ratio_range,
|
| 67 |
+
area_range=area_range,
|
| 68 |
+
max_attempts=max_attempts,
|
| 69 |
+
use_image_if_no_bounding_boxes=True)
|
| 70 |
+
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
| 71 |
+
|
| 72 |
+
# Crop the image to the specified bounding box.
|
| 73 |
+
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
| 74 |
+
target_height, target_width, _ = tf.unstack(bbox_size)
|
| 75 |
+
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
| 76 |
+
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
| 77 |
+
|
| 78 |
+
return image
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _at_least_x_are_equal(a, b, x):
|
| 82 |
+
"""At least `x` of `a` and `b` `Tensors` are equal."""
|
| 83 |
+
match = tf.equal(a, b)
|
| 84 |
+
match = tf.cast(match, tf.int32)
|
| 85 |
+
return tf.greater_equal(tf.reduce_sum(match), x)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _decode_and_random_crop(image_bytes, image_size):
|
| 89 |
+
"""Make a random crop of image_size."""
|
| 90 |
+
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
| 91 |
+
image = distorted_bounding_box_crop(
|
| 92 |
+
image_bytes,
|
| 93 |
+
bbox,
|
| 94 |
+
min_object_covered=0.1,
|
| 95 |
+
aspect_ratio_range=(3. / 4, 4. / 3.),
|
| 96 |
+
area_range=(0.08, 1.0),
|
| 97 |
+
max_attempts=10,
|
| 98 |
+
scope=None)
|
| 99 |
+
original_shape = tf.image.extract_jpeg_shape(image_bytes)
|
| 100 |
+
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
|
| 101 |
+
|
| 102 |
+
image = tf.cond(
|
| 103 |
+
bad,
|
| 104 |
+
lambda: _decode_and_center_crop(image_bytes, image_size),
|
| 105 |
+
lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda
|
| 106 |
+
[image_size, image_size])[0])
|
| 107 |
+
|
| 108 |
+
return image
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _decode_and_center_crop(image_bytes, image_size):
|
| 112 |
+
"""Crops to center of image with padding then scales image_size."""
|
| 113 |
+
shape = tf.image.extract_jpeg_shape(image_bytes)
|
| 114 |
+
image_height = shape[0]
|
| 115 |
+
image_width = shape[1]
|
| 116 |
+
|
| 117 |
+
padded_center_crop_size = tf.cast(
|
| 118 |
+
((image_size / (image_size + CROP_PADDING)) *
|
| 119 |
+
tf.cast(tf.minimum(image_height, image_width), tf.float32)),
|
| 120 |
+
tf.int32)
|
| 121 |
+
|
| 122 |
+
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
| 123 |
+
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
| 124 |
+
crop_window = tf.stack([offset_height, offset_width,
|
| 125 |
+
padded_center_crop_size, padded_center_crop_size])
|
| 126 |
+
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
| 127 |
+
image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
|
| 128 |
+
return image
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _flip(image):
|
| 132 |
+
"""Random horizontal image flip."""
|
| 133 |
+
image = tf.image.random_flip_left_right(image)
|
| 134 |
+
return image
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE,
|
| 138 |
+
augment_name=None,
|
| 139 |
+
randaug_num_layers=None, randaug_magnitude=None):
|
| 140 |
+
"""Preprocesses the given image for evaluation.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
| 144 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
| 145 |
+
image_size: image size.
|
| 146 |
+
augment_name: `string` that is the name of the augmentation method
|
| 147 |
+
to apply to the image. `autoaugment` if AutoAugment is to be used or
|
| 148 |
+
`randaugment` if RandAugment is to be used. If the value is `None` no
|
| 149 |
+
augmentation method will be applied applied. See autoaugment.py for more
|
| 150 |
+
details.
|
| 151 |
+
randaug_num_layers: 'int', if RandAug is used, what should the number of
|
| 152 |
+
layers be. See autoaugment.py for detailed description.
|
| 153 |
+
randaug_magnitude: 'int', if RandAug is used, what should the magnitude
|
| 154 |
+
be. See autoaugment.py for detailed description.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
A preprocessed image `Tensor`.
|
| 158 |
+
"""
|
| 159 |
+
image = _decode_and_random_crop(image_bytes, image_size)
|
| 160 |
+
image = _flip(image)
|
| 161 |
+
image = tf.reshape(image, [image_size, image_size, 3])
|
| 162 |
+
|
| 163 |
+
image = tf.image.convert_image_dtype(
|
| 164 |
+
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
|
| 165 |
+
|
| 166 |
+
if augment_name:
|
| 167 |
+
try:
|
| 168 |
+
import autoaugment # pylint: disable=g-import-not-at-top
|
| 169 |
+
except ImportError as e:
|
| 170 |
+
logging.exception('Autoaugment is not supported in TF 2.x.')
|
| 171 |
+
raise e
|
| 172 |
+
|
| 173 |
+
logging.info('Apply AutoAugment policy %s', augment_name)
|
| 174 |
+
input_image_type = image.dtype
|
| 175 |
+
image = tf.clip_by_value(image, 0.0, 255.0)
|
| 176 |
+
image = tf.cast(image, dtype=tf.uint8)
|
| 177 |
+
|
| 178 |
+
if augment_name == 'autoaugment':
|
| 179 |
+
logging.info('Apply AutoAugment policy %s', augment_name)
|
| 180 |
+
image = autoaugment.distort_image_with_autoaugment(image, 'v0')
|
| 181 |
+
elif augment_name == 'randaugment':
|
| 182 |
+
image = autoaugment.distort_image_with_randaugment(
|
| 183 |
+
image, randaug_num_layers, randaug_magnitude)
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError('Invalid value for augment_name: %s' % (augment_name))
|
| 186 |
+
|
| 187 |
+
image = tf.cast(image, dtype=input_image_type)
|
| 188 |
+
return image
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
|
| 192 |
+
"""Preprocesses the given image for evaluation.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
| 196 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
| 197 |
+
image_size: image size.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
A preprocessed image `Tensor`.
|
| 201 |
+
"""
|
| 202 |
+
image = _decode_and_center_crop(image_bytes, image_size)
|
| 203 |
+
image = tf.reshape(image, [image_size, image_size, 3])
|
| 204 |
+
image = tf.image.convert_image_dtype(
|
| 205 |
+
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
|
| 206 |
+
return image
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def preprocess_image(image_bytes,
|
| 210 |
+
is_training=False,
|
| 211 |
+
use_bfloat16=False,
|
| 212 |
+
image_size=IMAGE_SIZE,
|
| 213 |
+
augment_name=None,
|
| 214 |
+
randaug_num_layers=None,
|
| 215 |
+
randaug_magnitude=None):
|
| 216 |
+
"""Preprocesses the given image.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
| 220 |
+
is_training: `bool` for whether the preprocessing is for training.
|
| 221 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
| 222 |
+
image_size: image size.
|
| 223 |
+
augment_name: `string` that is the name of the augmentation method
|
| 224 |
+
to apply to the image. `autoaugment` if AutoAugment is to be used or
|
| 225 |
+
`randaugment` if RandAugment is to be used. If the value is `None` no
|
| 226 |
+
augmentation method will be applied applied. See autoaugment.py for more
|
| 227 |
+
details.
|
| 228 |
+
randaug_num_layers: 'int', if RandAug is used, what should the number of
|
| 229 |
+
layers be. See autoaugment.py for detailed description.
|
| 230 |
+
randaug_magnitude: 'int', if RandAug is used, what should the magnitude
|
| 231 |
+
be. See autoaugment.py for detailed description.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
A preprocessed image `Tensor` with value range of [0, 255].
|
| 235 |
+
"""
|
| 236 |
+
if is_training:
|
| 237 |
+
return preprocess_for_train(
|
| 238 |
+
image_bytes, use_bfloat16, image_size, augment_name,
|
| 239 |
+
randaug_num_layers, randaug_magnitude)
|
| 240 |
+
else:
|
| 241 |
+
return preprocess_for_eval(image_bytes, use_bfloat16, image_size)
|
external_data/original_tf/utils.py
ADDED
|
@@ -0,0 +1,405 @@
|
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Model utilities."""
|
| 16 |
+
|
| 17 |
+
from __future__ import absolute_import
|
| 18 |
+
from __future__ import division
|
| 19 |
+
from __future__ import print_function
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
from absl import logging
|
| 26 |
+
import numpy as np
|
| 27 |
+
import tensorflow.compat.v1 as tf
|
| 28 |
+
|
| 29 |
+
from tensorflow.python.tpu import tpu_function # pylint:disable=g-direct-tensorflow-import
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_learning_rate(initial_lr,
|
| 33 |
+
global_step,
|
| 34 |
+
steps_per_epoch=None,
|
| 35 |
+
lr_decay_type='exponential',
|
| 36 |
+
decay_factor=0.97,
|
| 37 |
+
decay_epochs=2.4,
|
| 38 |
+
total_steps=None,
|
| 39 |
+
warmup_epochs=5):
|
| 40 |
+
"""Build learning rate."""
|
| 41 |
+
if lr_decay_type == 'exponential':
|
| 42 |
+
assert steps_per_epoch is not None
|
| 43 |
+
decay_steps = steps_per_epoch * decay_epochs
|
| 44 |
+
lr = tf.train.exponential_decay(
|
| 45 |
+
initial_lr, global_step, decay_steps, decay_factor, staircase=True)
|
| 46 |
+
elif lr_decay_type == 'cosine':
|
| 47 |
+
assert total_steps is not None
|
| 48 |
+
lr = 0.5 * initial_lr * (
|
| 49 |
+
1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
|
| 50 |
+
elif lr_decay_type == 'constant':
|
| 51 |
+
lr = initial_lr
|
| 52 |
+
else:
|
| 53 |
+
assert False, 'Unknown lr_decay_type : %s' % lr_decay_type
|
| 54 |
+
|
| 55 |
+
if warmup_epochs:
|
| 56 |
+
logging.info('Learning rate warmup_epochs: %d', warmup_epochs)
|
| 57 |
+
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
| 58 |
+
warmup_lr = (
|
| 59 |
+
initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
|
| 60 |
+
warmup_steps, tf.float32))
|
| 61 |
+
lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
|
| 62 |
+
|
| 63 |
+
return lr
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def build_optimizer(learning_rate,
|
| 67 |
+
optimizer_name='rmsprop',
|
| 68 |
+
decay=0.9,
|
| 69 |
+
epsilon=0.001,
|
| 70 |
+
momentum=0.9):
|
| 71 |
+
"""Build optimizer."""
|
| 72 |
+
if optimizer_name == 'sgd':
|
| 73 |
+
logging.info('Using SGD optimizer')
|
| 74 |
+
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
|
| 75 |
+
elif optimizer_name == 'momentum':
|
| 76 |
+
logging.info('Using Momentum optimizer')
|
| 77 |
+
optimizer = tf.train.MomentumOptimizer(
|
| 78 |
+
learning_rate=learning_rate, momentum=momentum)
|
| 79 |
+
elif optimizer_name == 'rmsprop':
|
| 80 |
+
logging.info('Using RMSProp optimizer')
|
| 81 |
+
optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
|
| 82 |
+
epsilon)
|
| 83 |
+
else:
|
| 84 |
+
logging.fatal('Unknown optimizer: %s', optimizer_name)
|
| 85 |
+
|
| 86 |
+
return optimizer
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TpuBatchNormalization(tf.layers.BatchNormalization):
|
| 90 |
+
# class TpuBatchNormalization(tf.layers.BatchNormalization):
|
| 91 |
+
"""Cross replica batch normalization."""
|
| 92 |
+
|
| 93 |
+
def __init__(self, fused=False, **kwargs):
|
| 94 |
+
if fused in (True, None):
|
| 95 |
+
raise ValueError('TpuBatchNormalization does not support fused=True.')
|
| 96 |
+
super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
|
| 97 |
+
|
| 98 |
+
def _cross_replica_average(self, t, num_shards_per_group):
|
| 99 |
+
"""Calculates the average value of input tensor across TPU replicas."""
|
| 100 |
+
num_shards = tpu_function.get_tpu_context().number_of_shards
|
| 101 |
+
group_assignment = None
|
| 102 |
+
if num_shards_per_group > 1:
|
| 103 |
+
if num_shards % num_shards_per_group != 0:
|
| 104 |
+
raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
|
| 105 |
+
% (num_shards, num_shards_per_group))
|
| 106 |
+
num_groups = num_shards // num_shards_per_group
|
| 107 |
+
group_assignment = [[
|
| 108 |
+
x for x in range(num_shards) if x // num_shards_per_group == y
|
| 109 |
+
] for y in range(num_groups)]
|
| 110 |
+
return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
|
| 111 |
+
num_shards_per_group, t.dtype)
|
| 112 |
+
|
| 113 |
+
def _moments(self, inputs, reduction_axes, keep_dims):
|
| 114 |
+
"""Compute the mean and variance: it overrides the original _moments."""
|
| 115 |
+
shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
|
| 116 |
+
inputs, reduction_axes, keep_dims=keep_dims)
|
| 117 |
+
|
| 118 |
+
num_shards = tpu_function.get_tpu_context().number_of_shards or 1
|
| 119 |
+
if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
|
| 120 |
+
num_shards_per_group = 1
|
| 121 |
+
else:
|
| 122 |
+
num_shards_per_group = max(8, num_shards // 8)
|
| 123 |
+
logging.info('TpuBatchNormalization with num_shards_per_group %s',
|
| 124 |
+
num_shards_per_group)
|
| 125 |
+
if num_shards_per_group > 1:
|
| 126 |
+
# Compute variance using: Var[X]= E[X^2] - E[X]^2.
|
| 127 |
+
shard_square_of_mean = tf.math.square(shard_mean)
|
| 128 |
+
shard_mean_of_square = shard_variance + shard_square_of_mean
|
| 129 |
+
group_mean = self._cross_replica_average(
|
| 130 |
+
shard_mean, num_shards_per_group)
|
| 131 |
+
group_mean_of_square = self._cross_replica_average(
|
| 132 |
+
shard_mean_of_square, num_shards_per_group)
|
| 133 |
+
group_variance = group_mean_of_square - tf.math.square(group_mean)
|
| 134 |
+
return (group_mean, group_variance)
|
| 135 |
+
else:
|
| 136 |
+
return (shard_mean, shard_variance)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class BatchNormalization(tf.layers.BatchNormalization):
|
| 140 |
+
"""Fixed default name of BatchNormalization to match TpuBatchNormalization."""
|
| 141 |
+
|
| 142 |
+
def __init__(self, name='tpu_batch_normalization', **kwargs):
|
| 143 |
+
super(BatchNormalization, self).__init__(name=name, **kwargs)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def drop_connect(inputs, is_training, survival_prob):
|
| 147 |
+
"""Drop the entire conv with given survival probability."""
|
| 148 |
+
# "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
|
| 149 |
+
if not is_training:
|
| 150 |
+
return inputs
|
| 151 |
+
|
| 152 |
+
# Compute tensor.
|
| 153 |
+
batch_size = tf.shape(inputs)[0]
|
| 154 |
+
random_tensor = survival_prob
|
| 155 |
+
random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype)
|
| 156 |
+
binary_tensor = tf.floor(random_tensor)
|
| 157 |
+
# Unlike conventional way that multiply survival_prob at test time, here we
|
| 158 |
+
# divide survival_prob at training time, such that no additional compute is
|
| 159 |
+
# needed at test time.
|
| 160 |
+
output = tf.div(inputs, survival_prob) * binary_tensor
|
| 161 |
+
return output
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
|
| 165 |
+
"""Archive a checkpoint if the metric is better."""
|
| 166 |
+
ckpt_dir, ckpt_name = os.path.split(ckpt_path)
|
| 167 |
+
|
| 168 |
+
saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
|
| 169 |
+
saved_objective = float('-inf')
|
| 170 |
+
if tf.gfile.Exists(saved_objective_path):
|
| 171 |
+
with tf.gfile.GFile(saved_objective_path, 'r') as f:
|
| 172 |
+
saved_objective = float(f.read())
|
| 173 |
+
if saved_objective > ckpt_objective:
|
| 174 |
+
logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
|
| 175 |
+
return False
|
| 176 |
+
|
| 177 |
+
filenames = tf.gfile.Glob(ckpt_path + '.*')
|
| 178 |
+
if filenames is None:
|
| 179 |
+
logging.info('No files to copy for checkpoint %s', ckpt_path)
|
| 180 |
+
return False
|
| 181 |
+
|
| 182 |
+
# Clear the old folder.
|
| 183 |
+
dst_dir = os.path.join(ckpt_dir, 'archive')
|
| 184 |
+
if tf.gfile.Exists(dst_dir):
|
| 185 |
+
tf.gfile.DeleteRecursively(dst_dir)
|
| 186 |
+
tf.gfile.MakeDirs(dst_dir)
|
| 187 |
+
|
| 188 |
+
# Write checkpoints.
|
| 189 |
+
for f in filenames:
|
| 190 |
+
dest = os.path.join(dst_dir, os.path.basename(f))
|
| 191 |
+
tf.gfile.Copy(f, dest, overwrite=True)
|
| 192 |
+
ckpt_state = tf.train.generate_checkpoint_state_proto(
|
| 193 |
+
dst_dir,
|
| 194 |
+
model_checkpoint_path=ckpt_name,
|
| 195 |
+
all_model_checkpoint_paths=[ckpt_name])
|
| 196 |
+
with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f:
|
| 197 |
+
f.write(str(ckpt_state))
|
| 198 |
+
with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f:
|
| 199 |
+
f.write('%s' % ckpt_eval)
|
| 200 |
+
|
| 201 |
+
# Update the best objective.
|
| 202 |
+
with tf.gfile.GFile(saved_objective_path, 'w') as f:
|
| 203 |
+
f.write('%f' % ckpt_objective)
|
| 204 |
+
|
| 205 |
+
logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir)
|
| 206 |
+
return True
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_ema_vars():
|
| 210 |
+
"""Get all exponential moving average (ema) variables."""
|
| 211 |
+
ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
|
| 212 |
+
for v in tf.global_variables():
|
| 213 |
+
# We maintain mva for batch norm moving mean and variance as well.
|
| 214 |
+
if 'moving_mean' in v.name or 'moving_variance' in v.name:
|
| 215 |
+
ema_vars.append(v)
|
| 216 |
+
return list(set(ema_vars))
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer):
|
| 220 |
+
"""Wrap keras DepthwiseConv2D to tf.layers."""
|
| 221 |
+
|
| 222 |
+
pass
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class EvalCkptDriver(object):
|
| 226 |
+
"""A driver for running eval inference.
|
| 227 |
+
|
| 228 |
+
Attributes:
|
| 229 |
+
model_name: str. Model name to eval.
|
| 230 |
+
batch_size: int. Eval batch size.
|
| 231 |
+
image_size: int. Input image size, determined by model name.
|
| 232 |
+
num_classes: int. Number of classes, default to 1000 for ImageNet.
|
| 233 |
+
include_background_label: whether to include extra background label.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
def __init__(self,
|
| 237 |
+
model_name,
|
| 238 |
+
batch_size=1,
|
| 239 |
+
image_size=224,
|
| 240 |
+
num_classes=1000,
|
| 241 |
+
include_background_label=False):
|
| 242 |
+
"""Initialize internal variables."""
|
| 243 |
+
self.model_name = model_name
|
| 244 |
+
self.batch_size = batch_size
|
| 245 |
+
self.num_classes = num_classes
|
| 246 |
+
self.include_background_label = include_background_label
|
| 247 |
+
self.image_size = image_size
|
| 248 |
+
|
| 249 |
+
def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None):
|
| 250 |
+
"""Restore variables from checkpoint dir."""
|
| 251 |
+
sess.run(tf.global_variables_initializer())
|
| 252 |
+
checkpoint = tf.train.latest_checkpoint(ckpt_dir)
|
| 253 |
+
if enable_ema:
|
| 254 |
+
ema = tf.train.ExponentialMovingAverage(decay=0.0)
|
| 255 |
+
ema_vars = get_ema_vars()
|
| 256 |
+
var_dict = ema.variables_to_restore(ema_vars)
|
| 257 |
+
ema_assign_op = ema.apply(ema_vars)
|
| 258 |
+
else:
|
| 259 |
+
var_dict = get_ema_vars()
|
| 260 |
+
ema_assign_op = None
|
| 261 |
+
|
| 262 |
+
tf.train.get_or_create_global_step()
|
| 263 |
+
sess.run(tf.global_variables_initializer())
|
| 264 |
+
saver = tf.train.Saver(var_dict, max_to_keep=1)
|
| 265 |
+
saver.restore(sess, checkpoint)
|
| 266 |
+
|
| 267 |
+
if export_ckpt:
|
| 268 |
+
if ema_assign_op is not None:
|
| 269 |
+
sess.run(ema_assign_op)
|
| 270 |
+
saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
|
| 271 |
+
saver.save(sess, export_ckpt)
|
| 272 |
+
|
| 273 |
+
def build_model(self, features, is_training):
|
| 274 |
+
"""Build model with input features."""
|
| 275 |
+
del features, is_training
|
| 276 |
+
raise ValueError('Must be implemented by subclasses.')
|
| 277 |
+
|
| 278 |
+
def get_preprocess_fn(self):
|
| 279 |
+
raise ValueError('Must be implemented by subclasses.')
|
| 280 |
+
|
| 281 |
+
def build_dataset(self, filenames, labels, is_training):
|
| 282 |
+
"""Build input dataset."""
|
| 283 |
+
batch_drop_remainder = False
|
| 284 |
+
if 'condconv' in self.model_name and not is_training:
|
| 285 |
+
# CondConv layers can only be called with known batch dimension. Thus, we
|
| 286 |
+
# must drop all remaining examples that do not make up one full batch.
|
| 287 |
+
# To ensure all examples are evaluated, use a batch size that evenly
|
| 288 |
+
# divides the number of files.
|
| 289 |
+
batch_drop_remainder = True
|
| 290 |
+
num_files = len(filenames)
|
| 291 |
+
if num_files % self.batch_size != 0:
|
| 292 |
+
tf.logging.warn('Remaining examples in last batch are not being '
|
| 293 |
+
'evaluated.')
|
| 294 |
+
filenames = tf.constant(filenames)
|
| 295 |
+
labels = tf.constant(labels)
|
| 296 |
+
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
|
| 297 |
+
|
| 298 |
+
def _parse_function(filename, label):
|
| 299 |
+
image_string = tf.read_file(filename)
|
| 300 |
+
preprocess_fn = self.get_preprocess_fn()
|
| 301 |
+
image_decoded = preprocess_fn(
|
| 302 |
+
image_string, is_training, image_size=self.image_size)
|
| 303 |
+
image = tf.cast(image_decoded, tf.float32)
|
| 304 |
+
return image, label
|
| 305 |
+
|
| 306 |
+
dataset = dataset.map(_parse_function)
|
| 307 |
+
dataset = dataset.batch(self.batch_size,
|
| 308 |
+
drop_remainder=batch_drop_remainder)
|
| 309 |
+
|
| 310 |
+
iterator = dataset.make_one_shot_iterator()
|
| 311 |
+
images, labels = iterator.get_next()
|
| 312 |
+
return images, labels
|
| 313 |
+
|
| 314 |
+
def run_inference(self,
|
| 315 |
+
ckpt_dir,
|
| 316 |
+
image_files,
|
| 317 |
+
labels,
|
| 318 |
+
enable_ema=True,
|
| 319 |
+
export_ckpt=None):
|
| 320 |
+
"""Build and run inference on the target images and labels."""
|
| 321 |
+
label_offset = 1 if self.include_background_label else 0
|
| 322 |
+
with tf.Graph().as_default(), tf.Session() as sess:
|
| 323 |
+
images, labels = self.build_dataset(image_files, labels, False)
|
| 324 |
+
probs = self.build_model(images, is_training=False)
|
| 325 |
+
if isinstance(probs, tuple):
|
| 326 |
+
probs = probs[0]
|
| 327 |
+
|
| 328 |
+
self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt)
|
| 329 |
+
|
| 330 |
+
prediction_idx = []
|
| 331 |
+
prediction_prob = []
|
| 332 |
+
for _ in range(len(image_files) // self.batch_size):
|
| 333 |
+
out_probs = sess.run(probs)
|
| 334 |
+
idx = np.argsort(out_probs)[::-1]
|
| 335 |
+
prediction_idx.append(idx[:5] - label_offset)
|
| 336 |
+
prediction_prob.append([out_probs[pid] for pid in idx[:5]])
|
| 337 |
+
|
| 338 |
+
# Return the top 5 predictions (idx and prob) for each image.
|
| 339 |
+
return prediction_idx, prediction_prob
|
| 340 |
+
|
| 341 |
+
def eval_example_images(self,
|
| 342 |
+
ckpt_dir,
|
| 343 |
+
image_files,
|
| 344 |
+
labels_map_file,
|
| 345 |
+
enable_ema=True,
|
| 346 |
+
export_ckpt=None):
|
| 347 |
+
"""Eval a list of example images.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
ckpt_dir: str. Checkpoint directory path.
|
| 351 |
+
image_files: List[str]. A list of image file paths.
|
| 352 |
+
labels_map_file: str. The labels map file path.
|
| 353 |
+
enable_ema: enable exponential moving average.
|
| 354 |
+
export_ckpt: export ckpt folder.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
A tuple (pred_idx, pred_prob), where pred_idx is the top 5 prediction
|
| 358 |
+
index and pred_prob is the top 5 prediction probability.
|
| 359 |
+
"""
|
| 360 |
+
classes = json.loads(tf.gfile.Open(labels_map_file).read())
|
| 361 |
+
pred_idx, pred_prob = self.run_inference(
|
| 362 |
+
ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt)
|
| 363 |
+
for i in range(len(image_files)):
|
| 364 |
+
print('predicted class for image {}: '.format(image_files[i]))
|
| 365 |
+
for j, idx in enumerate(pred_idx[i]):
|
| 366 |
+
print(' -> top_{} ({:4.2f}%): {} '.format(j, pred_prob[i][j] * 100,
|
| 367 |
+
classes[str(idx)]))
|
| 368 |
+
return pred_idx, pred_prob
|
| 369 |
+
|
| 370 |
+
def eval_imagenet(self, ckpt_dir, imagenet_eval_glob,
|
| 371 |
+
imagenet_eval_label, num_images, enable_ema, export_ckpt):
|
| 372 |
+
"""Eval ImageNet images and report top1/top5 accuracy.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
ckpt_dir: str. Checkpoint directory path.
|
| 376 |
+
imagenet_eval_glob: str. File path glob for all eval images.
|
| 377 |
+
imagenet_eval_label: str. File path for eval label.
|
| 378 |
+
num_images: int. Number of images to eval: -1 means eval the whole
|
| 379 |
+
dataset.
|
| 380 |
+
enable_ema: enable exponential moving average.
|
| 381 |
+
export_ckpt: export checkpoint folder.
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
A tuple (top1, top5) for top1 and top5 accuracy.
|
| 385 |
+
"""
|
| 386 |
+
imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
|
| 387 |
+
imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
|
| 388 |
+
if num_images < 0:
|
| 389 |
+
num_images = len(imagenet_filenames)
|
| 390 |
+
image_files = imagenet_filenames[:num_images]
|
| 391 |
+
labels = imagenet_val_labels[:num_images]
|
| 392 |
+
|
| 393 |
+
pred_idx, _ = self.run_inference(
|
| 394 |
+
ckpt_dir, image_files, labels, enable_ema, export_ckpt)
|
| 395 |
+
top1_cnt, top5_cnt = 0.0, 0.0
|
| 396 |
+
for i, label in enumerate(labels):
|
| 397 |
+
top1_cnt += label in pred_idx[i][:1]
|
| 398 |
+
top5_cnt += label in pred_idx[i][:5]
|
| 399 |
+
if i % 100 == 0:
|
| 400 |
+
print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(
|
| 401 |
+
i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
|
| 402 |
+
sys.stdout.flush()
|
| 403 |
+
top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
|
| 404 |
+
print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5))
|
| 405 |
+
return top1, top5
|
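The warmup-plus-decay schedule that `build_learning_rate` constructs is easy to sanity-check outside a TF graph. A self-contained numpy sketch of the `exponential` branch with warmup; the constants are illustrative, not values taken from this repository:

```python
# Standalone numpy sketch of the schedule build_learning_rate() produces for
# lr_decay_type='exponential' with warmup (staircase decay). Constants below
# are examples only.
import numpy as np

initial_lr, decay_factor, decay_epochs = 0.016, 0.97, 2.4
steps_per_epoch, warmup_epochs = 1000, 5
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = steps_per_epoch * decay_epochs

def lr_at(step):
    decayed = initial_lr * decay_factor ** np.floor(step / decay_steps)  # staircase=True
    warmup = initial_lr * step / warmup_steps
    return warmup if step < warmup_steps else decayed

for step in (0, 2500, warmup_steps, 20000, 100000):
    print(step, round(float(lr_at(step)), 6))
```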
extract_tracks_from_videos.py
ADDED
|
@@ -0,0 +1,105 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
import random
|
| 5 |
+
import pickle
|
| 6 |
+
import tqdm
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from generate_aligned_tracks import ALIGNED_TRACKS_FILE_NAME
|
| 12 |
+
|
| 13 |
+
SEED = 0xDEADFACE
|
| 14 |
+
TRACK_LENGTH = 50
|
| 15 |
+
DETECTOR_STEP = 6
|
| 16 |
+
BOX_MULT = 1.5
|
| 17 |
+
|
| 18 |
+
TRACKS_ROOT = 'tracks'
|
| 19 |
+
BOXES_FILE_NAME = 'boxes.float32'
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser(description='Extracts tracks from videos')
|
| 24 |
+
parser.add_argument('--num_parts', type=int, default=1, help='Number of parts')
|
| 25 |
+
parser.add_argument('--part', type=int, default=0, help='Part')
|
| 26 |
+
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
with open('config.yaml', 'r') as f:
|
| 30 |
+
config = yaml.load(f)
|
| 31 |
+
|
| 32 |
+
with open(os.path.join(config['ARTIFACTS_PATH'], ALIGNED_TRACKS_FILE_NAME), 'rb') as f:
|
| 33 |
+
aligned_tracks = pickle.load(f)
|
| 34 |
+
|
| 35 |
+
part_size = len(aligned_tracks) // args.num_parts + 1
|
| 36 |
+
assert part_size * args.num_parts >= len(aligned_tracks)
|
| 37 |
+
part_start = part_size * args.part
|
| 38 |
+
part_end = min(part_start + part_size, len(aligned_tracks))
|
| 39 |
+
print('Part {} ({}, {})'.format(args.part, part_start, part_end))
|
| 40 |
+
|
| 41 |
+
random.seed(SEED)
|
| 42 |
+
for real_video, fake_video, aligned_track in tqdm.tqdm(aligned_tracks[part_start:part_end]):
|
| 43 |
+
if len(aligned_track) < TRACK_LENGTH // DETECTOR_STEP:
|
| 44 |
+
continue
|
| 45 |
+
real_boxes = [item[1] for item in aligned_track]
|
| 46 |
+
fake_boxes = [item[2] for item in aligned_track]
|
| 47 |
+
start_idx = random.randint(0, len(aligned_track) - TRACK_LENGTH // DETECTOR_STEP)
|
| 48 |
+
start_frame = aligned_track[start_idx][0] * DETECTOR_STEP
|
| 49 |
+
middle_idx = start_idx + TRACK_LENGTH // DETECTOR_STEP // 2
|
| 50 |
+
|
| 51 |
+
if random.choice([False, True]):
|
| 52 |
+
xmin, ymin, xmax, ymax = real_boxes[middle_idx]
|
| 53 |
+
else:
|
| 54 |
+
xmin, ymin, xmax, ymax = fake_boxes[middle_idx]
|
| 55 |
+
|
| 56 |
+
width = xmax - xmin
|
| 57 |
+
height = ymax - ymin
|
| 58 |
+
xcenter = xmin + width / 2
|
| 59 |
+
ycenter = ymin + height / 2
|
| 60 |
+
width = width * BOX_MULT
|
| 61 |
+
height = height * BOX_MULT
|
| 62 |
+
xmin = xcenter - width / 2
|
| 63 |
+
ymin = ycenter - height / 2
|
| 64 |
+
xmax = xmin + width
|
| 65 |
+
ymax = ymin + height
|
| 66 |
+
|
| 67 |
+
for video, boxes in [(real_video, real_boxes), (fake_video, fake_boxes)]:
|
| 68 |
+
capture = cv2.VideoCapture(os.path.join(config['DFDC_DATA_PATH'], video))
|
| 69 |
+
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 70 |
+
if frame_count == 0:
|
| 71 |
+
continue
|
| 72 |
+
frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 73 |
+
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 74 |
+
|
| 75 |
+
xmin = max(int(xmin), 0)
|
| 76 |
+
xmax = min(int(xmax), frame_width)
|
| 77 |
+
ymin = max(int(ymin), 0)
|
| 78 |
+
ymax = min(int(ymax), frame_height)
|
| 79 |
+
|
| 80 |
+
dst_root = os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT,
|
| 81 |
+
video + '_{}_{}_{}'.format(start_frame, xmin, ymin))
|
| 82 |
+
if os.path.exists(dst_root):
|
| 83 |
+
continue
|
| 84 |
+
os.makedirs(dst_root)
|
| 85 |
+
for i in range(start_frame + TRACK_LENGTH):
|
| 86 |
+
capture.grab()
|
| 87 |
+
if i < start_frame:
|
| 88 |
+
continue
|
| 89 |
+
ret, frame = capture.retrieve()
|
| 90 |
+
if not ret:
|
| 91 |
+
continue
|
| 92 |
+
face = frame[ymin:ymax, xmin:xmax]
|
| 93 |
+
dst_path = os.path.join(dst_root, '{}.png'.format(i - start_frame))
|
| 94 |
+
cv2.imwrite(dst_path, face)
|
| 95 |
+
|
| 96 |
+
boxes = np.array(boxes, dtype=np.float32)
|
| 97 |
+
boxes[:, 0] -= xmin
|
| 98 |
+
boxes[:, 1] -= ymin
|
| 99 |
+
boxes[:, 2] -= xmin
|
| 100 |
+
boxes[:, 3] -= ymin
|
| 101 |
+
boxes.tofile(os.path.join(dst_root, BOXES_FILE_NAME))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if __name__ == '__main__':
|
| 105 |
+
main()
|
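The crop geometry used above (and again in `predict.py`) is a centre-preserving expansion of the detector box by `BOX_MULT`, clipped to the frame. A standalone sketch with made-up numbers:

```python
# Minimal sketch of the crop geometry above: the chosen detector box is
# expanded around its centre by BOX_MULT and clipped to the frame bounds.
# Pure Python; the numbers in the call are invented.
BOX_MULT = 1.5

def expand_and_clip(box, frame_width, frame_height, mult=BOX_MULT):
    xmin, ymin, xmax, ymax = box
    width, height = (xmax - xmin) * mult, (ymax - ymin) * mult
    xcenter, ycenter = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
    xmin, ymin = xcenter - width / 2, ycenter - height / 2
    xmax, ymax = xmin + width, ymin + height
    return (max(int(xmin), 0), max(int(ymin), 0),
            min(int(xmax), frame_width), min(int(ymax), frame_height))

print(expand_and_clip((100, 50, 200, 170), frame_width=1920, frame_height=1080))
# -> (75, 20, 225, 200)
```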
generate_aligned_tracks.py
ADDED
|
@@ -0,0 +1,99 @@
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
import json
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
import tqdm
|
| 7 |
+
import pickle
|
| 8 |
+
|
| 9 |
+
from tracker.utils import iou
|
| 10 |
+
|
| 11 |
+
from generate_tracks import TRACKS_FILE_NAME
|
| 12 |
+
|
| 13 |
+
MIN_TRACK_LENGTH = 5
|
| 14 |
+
IOU_THRESHOLD = 0.5
|
| 15 |
+
METADATA_FILE_NAME = 'metadata.json'
|
| 16 |
+
|
| 17 |
+
ALIGNED_TRACKS_FILE_NAME = 'aligned_tracks.pkl'
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_track(tracks, min_track_length):
|
| 21 |
+
good_tracks = [track for track in tracks if len(track) >= min_track_length]
|
| 22 |
+
if len(good_tracks) == 1:
|
| 23 |
+
return good_tracks[0]
|
| 24 |
+
else:
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main():
|
| 29 |
+
with open('config.yaml', 'r') as f:
|
| 30 |
+
config = yaml.load(f)
|
| 31 |
+
|
| 32 |
+
video_to_meta = {}
|
| 33 |
+
|
| 34 |
+
for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], '**', METADATA_FILE_NAME), recursive=True):
|
| 35 |
+
root = os.path.basename(os.path.dirname(path))
|
| 36 |
+
with open(path, 'r') as f:
|
| 37 |
+
for video, meta in json.load(f).items():
|
| 38 |
+
video_to_meta[os.path.join(root, video)] = meta
|
| 39 |
+
|
| 40 |
+
real_video_to_fake_videos = defaultdict(list)
|
| 41 |
+
for video in video_to_meta:
|
| 42 |
+
root = os.path.dirname(video)
|
| 43 |
+
meta = video_to_meta[video]
|
| 44 |
+
if meta['label'] == 'FAKE':
|
| 45 |
+
original_video = os.path.join(root, meta['original'])
|
| 46 |
+
real_video_to_fake_videos[original_video].append(video)
|
| 47 |
+
|
| 48 |
+
print('Total number of real videos: {}'.format(len(real_video_to_fake_videos)))
|
| 49 |
+
print('Total number of fake videos: {}'.format(sum([len(fake_videos) for fake_videos in real_video_to_fake_videos.values()])))
|
| 50 |
+
|
| 51 |
+
with open(os.path.join(config['ARTIFACTS_PATH'], TRACKS_FILE_NAME), 'rb') as f:
|
| 52 |
+
video_to_tracks = pickle.load(f)
|
| 53 |
+
|
| 54 |
+
real_fake_aligned_tracks = []
|
| 55 |
+
real_videos = sorted(real_video_to_fake_videos)
|
| 56 |
+
for real_video in tqdm.tqdm(real_videos):
|
| 57 |
+
if real_video not in video_to_tracks:
|
| 58 |
+
continue
|
| 59 |
+
real_tracks = [track for track in video_to_tracks[real_video] if len(track) >= MIN_TRACK_LENGTH]
|
| 60 |
+
|
| 61 |
+
for fake_video in real_video_to_fake_videos[real_video]:
|
| 62 |
+
if fake_video not in video_to_tracks:
|
| 63 |
+
continue
|
| 64 |
+
fake_tracks = [track for track in video_to_tracks[fake_video] if len(track) >= MIN_TRACK_LENGTH]
|
| 65 |
+
|
| 66 |
+
for real_track in real_tracks:
|
| 67 |
+
real_frame_idx_to_bbox = {}
|
| 68 |
+
for real_frame_idx, real_bbox in real_track:
|
| 69 |
+
real_frame_idx_to_bbox[real_frame_idx] = real_bbox
|
| 70 |
+
|
| 71 |
+
for fake_track in fake_tracks:
|
| 72 |
+
fake_frame_idx_to_bbox = {}
|
| 73 |
+
ious = []
|
| 74 |
+
for fake_frame_idx, fake_bbox in fake_track:
|
| 75 |
+
fake_frame_idx_to_bbox[fake_frame_idx] = fake_bbox
|
| 76 |
+
if fake_frame_idx in real_frame_idx_to_bbox:
|
| 77 |
+
real_bbox = real_frame_idx_to_bbox[fake_frame_idx]
|
| 78 |
+
ious.append(iou(real_bbox, fake_bbox))
|
| 79 |
+
if len(ious) > 0 and min(ious) > IOU_THRESHOLD:
|
| 80 |
+
start_frame_idx = max(min(real_frame_idx_to_bbox), min(fake_frame_idx_to_bbox))
|
| 81 |
+
end_frame_idx = min(max(real_frame_idx_to_bbox), max(fake_frame_idx_to_bbox)) + 1
|
| 82 |
+
assert start_frame_idx < end_frame_idx
|
| 83 |
+
real_fake_aligned_track = []
|
| 84 |
+
for frame_idx in range(start_frame_idx, end_frame_idx):
|
| 85 |
+
real_bbox = real_frame_idx_to_bbox[frame_idx]
|
| 86 |
+
fake_bbox = fake_frame_idx_to_bbox[frame_idx]
|
| 87 |
+
assert iou(real_bbox, fake_bbox) > IOU_THRESHOLD
|
| 88 |
+
real_fake_aligned_track.append((frame_idx, real_bbox, fake_bbox))
|
| 89 |
+
real_fake_aligned_tracks.append((real_video, fake_video, real_fake_aligned_track))
|
| 90 |
+
break
|
| 91 |
+
|
| 92 |
+
print('Total number of tracks: {}'.format(len(real_fake_aligned_tracks)))
|
| 93 |
+
|
| 94 |
+
with open(os.path.join(config['ARTIFACTS_PATH'], ALIGNED_TRACKS_FILE_NAME), 'wb') as f:
|
| 95 |
+
pickle.dump(real_fake_aligned_tracks, f)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == '__main__':
|
| 99 |
+
main()
|
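A toy, self-contained illustration of the alignment rule implemented above: a real track and a fake track are paired only when every frame they share overlaps with IoU above `IOU_THRESHOLD`, and the aligned track covers exactly the shared frame range (boxes below are invented):

```python
# Toy illustration of the track alignment rule; all coordinates are invented.
IOU_THRESHOLD = 0.5

def iou(b1, b2):
    x0, y0 = max(b1[0], b2[0]), max(b1[1], b2[1])
    x1, y1 = min(b1[2], b2[2]), min(b1[3], b2[3])
    if x1 <= x0 or y1 <= y0:
        return 0.0
    inter = (x1 - x0) * (y1 - y0)
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    return inter / (area1 + area2 - inter)

real_track = [(0, (10, 10, 50, 50)), (1, (12, 10, 52, 50)), (2, (14, 10, 54, 50))]
fake_track = [(1, (13, 11, 52, 51)), (2, (15, 11, 55, 51)), (3, (17, 11, 57, 51))]

real_by_frame = dict(real_track)
fake_by_frame = dict(fake_track)
shared = sorted(set(real_by_frame) & set(fake_by_frame))
if shared and min(iou(real_by_frame[f], fake_by_frame[f]) for f in shared) > IOU_THRESHOLD:
    aligned = [(f, real_by_frame[f], fake_by_frame[f]) for f in shared]
    print(aligned)  # frames 1 and 2 are shared and overlap well enough
```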
generate_track_pairs.py
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
from generate_aligned_tracks import METADATA_FILE_NAME
|
| 8 |
+
from extract_tracks_from_videos import TRACKS_ROOT
|
| 9 |
+
|
| 10 |
+
TRACK_PAIRS_FILE_NAME = 'track_pairs.txt'
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def main():
|
| 14 |
+
with open('config.yaml', 'r') as f:
|
| 15 |
+
config = yaml.load(f)
|
| 16 |
+
|
| 17 |
+
video_to_tracks = defaultdict(list)
|
| 18 |
+
|
| 19 |
+
for path in glob.iglob(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, 'dfdc_train_part_*', '*.mp4_*')):
|
| 20 |
+
parts = path.split('/')
|
| 21 |
+
rel_path = '/'.join(parts[-2:])
|
| 22 |
+
video = '_'.join(rel_path.split('_')[:-3])
|
| 23 |
+
video_to_tracks[video].append(rel_path)
|
| 24 |
+
|
| 25 |
+
video_to_meta = {}
|
| 26 |
+
|
| 27 |
+
for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], '**', METADATA_FILE_NAME), recursive=True):
|
| 28 |
+
root = os.path.basename(os.path.dirname(path))
|
| 29 |
+
with open(path, 'r') as f:
|
| 30 |
+
for video, meta in json.load(f).items():
|
| 31 |
+
video_to_meta[os.path.join(root, video)] = meta
|
| 32 |
+
|
| 33 |
+
fake_video_to_real_video = {}
|
| 34 |
+
for video in video_to_meta:
|
| 35 |
+
root = os.path.dirname(video)
|
| 36 |
+
meta = video_to_meta[video]
|
| 37 |
+
if meta['label'] == 'FAKE':
|
| 38 |
+
original_video = os.path.join(root, meta['original'])
|
| 39 |
+
fake_video_to_real_video[video] = original_video
|
| 40 |
+
|
| 41 |
+
print('Total number of fake videos: {}'.format(len(fake_video_to_real_video)))
|
| 42 |
+
|
| 43 |
+
track_pairs = []
|
| 44 |
+
|
| 45 |
+
fake_videos = sorted(fake_video_to_real_video)
|
| 46 |
+
for fake_video in fake_videos:
|
| 47 |
+
real_video = fake_video_to_real_video[fake_video]
|
| 48 |
+
fake_tracks = video_to_tracks[fake_video]
|
| 49 |
+
real_tracks = video_to_tracks[real_video]
|
| 50 |
+
|
| 51 |
+
for fake_track in fake_tracks:
|
| 52 |
+
if not os.path.exists(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, fake_track, '0.png')):
|
| 53 |
+
continue
|
| 54 |
+
suffix = fake_track[len(fake_video):]
|
| 55 |
+
for real_track in real_tracks:
|
| 56 |
+
if not os.path.exists(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, real_track, '0.png')):
|
| 57 |
+
continue
|
| 58 |
+
if real_track.endswith(suffix):
|
| 59 |
+
track_pairs.append((real_track, fake_track))
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
print('Total number of track pairs: {}'.format(len(track_pairs)))
|
| 63 |
+
|
| 64 |
+
with open(os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME), 'w') as f:
|
| 65 |
+
for real_track, fake_track in track_pairs:
|
| 66 |
+
f.write('{},{}\n'.format(real_track, fake_track))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
main()
|
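The pairing logic above reduces to suffix matching on the track directory names, which encode `<video>_<start_frame>_<xmin>_<ymin>` (see `extract_tracks_from_videos.py`). A small sketch with invented paths:

```python
# Small sketch of the pairing rule: a fake track and a real track match when
# they share the '_<start_frame>_<xmin>_<ymin>' suffix. Paths are invented.
fake_video = 'dfdc_train_part_0/aaaa.mp4'
fake_track = 'dfdc_train_part_0/aaaa.mp4_120_75_20'
real_tracks = ['dfdc_train_part_0/bbbb.mp4_0_10_10',
               'dfdc_train_part_0/bbbb.mp4_120_75_20']

suffix = fake_track[len(fake_video):]                               # '_120_75_20'
match = next((t for t in real_tracks if t.endswith(suffix)), None)
print(match)                                                        # .../bbbb.mp4_120_75_20
```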
generate_tracks.py
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import tqdm
|
| 4 |
+
import glob
|
| 5 |
+
import pickle
|
| 6 |
+
|
| 7 |
+
from tracker.iou_tracker import track_iou
|
| 8 |
+
from detect_faces_on_videos import DETECTIONS_FILE_NAME, DETECTIONS_ROOT
|
| 9 |
+
|
| 10 |
+
SIGMA_L = 0.3
|
| 11 |
+
SIGMA_H = 0.9
|
| 12 |
+
SIGMA_IOU = 0.3
|
| 13 |
+
T_MIN = 1
|
| 14 |
+
|
| 15 |
+
TRACKS_FILE_NAME = 'tracks.pkl'
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_tracks(detections):
|
| 19 |
+
if len(detections) == 0:
|
| 20 |
+
return []
|
| 21 |
+
|
| 22 |
+
converted_detections = []
|
| 23 |
+
for i, detections_per_frame in enumerate(detections):
|
| 24 |
+
converted_detections_per_frame = []
|
| 25 |
+
for j, (bbox, score) in enumerate(zip(detections_per_frame['boxes'], detections_per_frame['scores'])):
|
| 26 |
+
bbox = tuple(bbox.tolist())
|
| 27 |
+
converted_detections_per_frame.append({'bbox': bbox, 'score': score})
|
| 28 |
+
converted_detections.append(converted_detections_per_frame)
|
| 29 |
+
|
| 30 |
+
tracks = track_iou(converted_detections, SIGMA_L, SIGMA_H, SIGMA_IOU, T_MIN)
|
| 31 |
+
tracks_converted = []
|
| 32 |
+
for track in tracks:
|
| 33 |
+
track_converted = []
|
| 34 |
+
start_frame = track['start_frame'] - 1
|
| 35 |
+
for i, bbox in enumerate(track['bboxes']):
|
| 36 |
+
track_converted.append((start_frame + i, bbox))
|
| 37 |
+
tracks_converted.append(track_converted)
|
| 38 |
+
|
| 39 |
+
return tracks_converted
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def main():
|
| 43 |
+
with open('config.yaml', 'r') as f:
|
| 44 |
+
config = yaml.load(f)
|
| 45 |
+
|
| 46 |
+
root_dir = os.path.join(config['ARTIFACTS_PATH'], DETECTIONS_ROOT)
|
| 47 |
+
detections_content = []
|
| 48 |
+
for path in glob.iglob(os.path.join(root_dir, '**', DETECTIONS_FILE_NAME), recursive=True):
|
| 49 |
+
rel_path = path[len(root_dir) + 1:]
|
| 50 |
+
detections_content.append(rel_path)
|
| 51 |
+
|
| 52 |
+
detections_content = sorted(detections_content)
|
| 53 |
+
print('Total number of videos: {}'.format(len(detections_content)))
|
| 54 |
+
|
| 55 |
+
video_to_tracks = {}
|
| 56 |
+
for rel_path in tqdm.tqdm(detections_content):
|
| 57 |
+
video = os.path.dirname(rel_path)
|
| 58 |
+
with open(os.path.join(root_dir, rel_path), 'rb') as f:
|
| 59 |
+
detections = pickle.load(f)
|
| 60 |
+
video_to_tracks[video] = get_tracks(detections)
|
| 61 |
+
|
| 62 |
+
track_count = sum([len(tracks) for tracks in video_to_tracks.values()])
|
| 63 |
+
print('Total number of tracks: {}'.format(track_count))
|
| 64 |
+
|
| 65 |
+
with open(os.path.join(config['ARTIFACTS_PATH'], TRACKS_FILE_NAME), 'wb') as f:
|
| 66 |
+
pickle.dump(video_to_tracks, f)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
main()
|
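One detail worth making explicit: `track_iou` reports `start_frame` counting from 1, while the detection list it was given is 0-indexed, hence the `start_frame - 1` in `get_tracks` above. A minimal sketch with an invented track:

```python
# Toy illustration of the index bookkeeping in get_tracks(): the tracker's
# 1-based start_frame is shifted back so each bbox is paired with its 0-based
# detection-frame index. The track below is invented.
track = {'start_frame': 3, 'bboxes': [(0, 0, 10, 10), (1, 0, 11, 10)]}
start_frame = track['start_frame'] - 1
converted = [(start_frame + i, bbox) for i, bbox in enumerate(track['bboxes'])]
print(converted)  # [(2, (0, 0, 10, 10)), (3, (1, 0, 11, 10))]
```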
images/augmented_mixup.jpg
ADDED
|
Git LFS Details
|
images/clip_example.jpg
ADDED
|
images/first_and_second_model_inputs.jpg
ADDED
|
images/mixup_example.jpg
ADDED
|
Git LFS Details
|
images/pred_transform.jpg
ADDED
|
images/third_model_input.jpg
ADDED
|
models/.gitkeep
ADDED
|
File without changes
|
predict.py
ADDED
|
@@ -0,0 +1,399 @@
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import glob
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
from torchvision.models.detection.transform import GeneralizedRCNNTransform
|
| 13 |
+
|
| 14 |
+
from albumentations import Compose, SmallestMaxSize, CenterCrop, Normalize, PadIfNeeded
|
| 15 |
+
from albumentations.pytorch import ToTensor
|
| 16 |
+
|
| 17 |
+
from dsfacedetector.face_ssd_infer import SSD
|
| 18 |
+
from tracker.iou_tracker import track_iou
|
| 19 |
+
from efficientnet_pytorch.model import EfficientNet, MBConvBlock
|
| 20 |
+
|
| 21 |
+
DETECTOR_WEIGHTS_PATH = 'WIDERFace_DSFD_RES152.fp16.pth'
|
| 22 |
+
DETECTOR_THRESHOLD = 0.3
|
| 23 |
+
DETECTOR_MIN_SIZE = 512
|
| 24 |
+
DETECTOR_MAX_SIZE = 512
|
| 25 |
+
DETECTOR_MEAN = (104.0, 117.0, 123.0)
|
| 26 |
+
DETECTOR_STD = (1.0, 1.0, 1.0)
|
| 27 |
+
DETECTOR_BATCH_SIZE = 16
|
| 28 |
+
DETECTOR_STEP = 3
|
| 29 |
+
|
| 30 |
+
TRACKER_SIGMA_L = 0.3
|
| 31 |
+
TRACKER_SIGMA_H = 0.9
|
| 32 |
+
TRACKER_SIGMA_IOU = 0.3
|
| 33 |
+
TRACKER_T_MIN = 7
|
| 34 |
+
|
| 35 |
+
VIDEO_MODEL_BBOX_MULT = 1.5
|
| 36 |
+
VIDEO_MODEL_MIN_SIZE = 224
|
| 37 |
+
VIDEO_MODEL_CROP_HEIGHT = 224
|
| 38 |
+
VIDEO_MODEL_CROP_WIDTH = 192
|
| 39 |
+
VIDEO_FACE_MODEL_TRACK_STEP = 2
|
| 40 |
+
VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH = 7
|
| 41 |
+
VIDEO_SEQUENCE_MODEL_TRACK_STEP = 14
|
| 42 |
+
|
| 43 |
+
VIDEO_SEQUENCE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth'
|
| 44 |
+
FIRST_VIDEO_FACE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth'
|
| 45 |
+
SECOND_VIDEO_FACE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth'
|
| 46 |
+
|
| 47 |
+
VIDEO_BATCH_SIZE = 1
|
| 48 |
+
VIDEO_TARGET_FPS = 15
|
| 49 |
+
VIDEO_NUM_WORKERS = 0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class UnlabeledVideoDataset(Dataset):
|
| 53 |
+
def __init__(self, root_dir, content=None):
|
| 54 |
+
self.root_dir = os.path.normpath(root_dir)
|
| 55 |
+
if content is not None:
|
| 56 |
+
self.content = content
|
| 57 |
+
else:
|
| 58 |
+
self.content = []
|
| 59 |
+
for path in glob.iglob(os.path.join(self.root_dir, '**', '*.mp4'), recursive=True):
|
| 60 |
+
rel_path = path[len(self.root_dir) + 1:]
|
| 61 |
+
self.content.append(rel_path)
|
| 62 |
+
self.content = sorted(self.content)
|
| 63 |
+
|
| 64 |
+
def __len__(self):
|
| 65 |
+
return len(self.content)
|
| 66 |
+
|
| 67 |
+
def __getitem__(self, idx):
|
| 68 |
+
rel_path = self.content[idx]
|
| 69 |
+
path = os.path.join(self.root_dir, rel_path)
|
| 70 |
+
|
| 71 |
+
sample = {
|
| 72 |
+
'frames': [],
|
| 73 |
+
'index': idx
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
capture = cv2.VideoCapture(path)
|
| 77 |
+
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 78 |
+
if frame_count == 0:
|
| 79 |
+
return sample
|
| 80 |
+
|
| 81 |
+
fps = int(capture.get(cv2.CAP_PROP_FPS))
|
| 82 |
+
video_step = round(fps / VIDEO_TARGET_FPS)
|
| 83 |
+
if video_step == 0:
|
| 84 |
+
return sample
|
| 85 |
+
|
| 86 |
+
for i in range(frame_count):
|
| 87 |
+
capture.grab()
|
| 88 |
+
if i % video_step != 0:
|
| 89 |
+
continue
|
| 90 |
+
ret, frame = capture.retrieve()
|
| 91 |
+
if not ret:
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
sample['frames'].append(frame)
|
| 95 |
+
|
| 96 |
+
return sample
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Detector(object):
|
| 100 |
+
def __init__(self, weights_path):
|
| 101 |
+
self.model = SSD('test')
|
| 102 |
+
self.model.cuda().eval()
|
| 103 |
+
|
| 104 |
+
state = torch.load(weights_path, map_location=lambda storage, loc: storage)
|
| 105 |
+
state = {key: value.float() for key, value in state.items()}
|
| 106 |
+
self.model.load_state_dict(state)
|
| 107 |
+
|
| 108 |
+
self.transform = GeneralizedRCNNTransform(DETECTOR_MIN_SIZE, DETECTOR_MAX_SIZE, DETECTOR_MEAN, DETECTOR_STD)
|
| 109 |
+
self.transform.eval()
|
| 110 |
+
|
| 111 |
+
def detect(self, images):
|
| 112 |
+
images = torch.stack([torch.from_numpy(image).cuda() for image in images])
|
| 113 |
+
images = images.transpose(1, 3).transpose(2, 3).float()
|
| 114 |
+
original_image_sizes = [img.shape[-2:] for img in images]
|
| 115 |
+
images, _ = self.transform(images, None)
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
detections_batch = self.model(images.tensors).cpu().numpy()
|
| 118 |
+
result = []
|
| 119 |
+
for detections, image_size in zip(detections_batch, images.image_sizes):
|
| 120 |
+
scores = detections[1, :, 0]
|
| 121 |
+
keep_idxs = scores > DETECTOR_THRESHOLD
|
| 122 |
+
detections = detections[1, keep_idxs, :]
|
| 123 |
+
detections = detections[:, [1, 2, 3, 4, 0]]
|
| 124 |
+
detections[:, 0] *= image_size[1]
|
| 125 |
+
detections[:, 1] *= image_size[0]
|
| 126 |
+
detections[:, 2] *= image_size[1]
|
| 127 |
+
detections[:, 3] *= image_size[0]
|
| 128 |
+
result.append({
|
| 129 |
+
'scores': torch.from_numpy(detections[:, 4]),
|
| 130 |
+
'boxes': torch.from_numpy(detections[:, :4])
|
| 131 |
+
})
|
| 132 |
+
|
| 133 |
+
result = self.transform.postprocess(result, images.image_sizes, original_image_sizes)
|
| 134 |
+
return result
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_tracks(detections):
|
| 138 |
+
if len(detections) == 0:
|
| 139 |
+
return []
|
| 140 |
+
|
| 141 |
+
converted_detections = []
|
| 142 |
+
frame_bbox_to_face_idx = {}
|
| 143 |
+
for i, detections_per_frame in enumerate(detections):
|
| 144 |
+
converted_detections_per_frame = []
|
| 145 |
+
for j, (bbox, score) in enumerate(zip(detections_per_frame['boxes'], detections_per_frame['scores'])):
|
| 146 |
+
bbox = tuple(bbox.tolist())
|
| 147 |
+
frame_bbox_to_face_idx[(i, bbox)] = j
|
| 148 |
+
converted_detections_per_frame.append({'bbox': bbox, 'score': score})
|
| 149 |
+
converted_detections.append(converted_detections_per_frame)
|
| 150 |
+
|
| 151 |
+
tracks = track_iou(converted_detections, TRACKER_SIGMA_L, TRACKER_SIGMA_H, TRACKER_SIGMA_IOU, TRACKER_T_MIN)
|
| 152 |
+
tracks_converted = []
|
| 153 |
+
for track in tracks:
|
| 154 |
+
start_frame = track['start_frame'] - 1
|
| 155 |
+
bboxes = np.array(track['bboxes'], dtype=np.float32)
|
| 156 |
+
frame_indices = np.arange(start_frame, start_frame + len(bboxes)) * DETECTOR_STEP
|
| 157 |
+
interp_frame_indices = np.arange(frame_indices[0], frame_indices[-1] + 1)
|
| 158 |
+
interp_bboxes = np.zeros((len(interp_frame_indices), 4), dtype=np.float32)
|
| 159 |
+
for i in range(4):
|
| 160 |
+
interp_bboxes[:, i] = np.interp(interp_frame_indices, frame_indices, bboxes[:, i])
|
| 161 |
+
|
| 162 |
+
track_converted = []
|
| 163 |
+
for frame_idx, bbox in zip(interp_frame_indices, interp_bboxes):
|
| 164 |
+
track_converted.append((frame_idx, bbox))
|
| 165 |
+
tracks_converted.append(track_converted)
|
| 166 |
+
|
| 167 |
+
return tracks_converted
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class SeqExpandConv(nn.Module):
|
| 171 |
+
def __init__(self, in_channels, out_channels, seq_length):
|
| 172 |
+
super(SeqExpandConv, self).__init__()
|
| 173 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
|
| 174 |
+
self.seq_length = seq_length
|
| 175 |
+
|
| 176 |
+
def forward(self, x):
|
| 177 |
+
batch_size, in_channels, height, width = x.shape
|
| 178 |
+
x = x.view(batch_size // self.seq_length, self.seq_length, in_channels, height, width)
|
| 179 |
+
x = self.conv(x.transpose(1, 2).contiguous()).transpose(2, 1).contiguous()
|
| 180 |
+
x = x.flatten(0, 1)
|
| 181 |
+
return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class TrackSequencesClassifier(object):
|
| 185 |
+
def __init__(self, weights_path):
|
| 186 |
+
model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 187 |
+
|
| 188 |
+
for module in model.modules():
|
| 189 |
+
if isinstance(module, MBConvBlock):
|
| 190 |
+
if module._block_args.expand_ratio != 1:
|
| 191 |
+
expand_conv = module._expand_conv
|
| 192 |
+
seq_expand_conv = SeqExpandConv(expand_conv.in_channels, expand_conv.out_channels,
|
| 193 |
+
VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH)
|
| 194 |
+
module._expand_conv = seq_expand_conv
|
| 195 |
+
self.model = model.cuda().eval()
|
| 196 |
+
|
| 197 |
+
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 198 |
+
self.transform = Compose(
|
| 199 |
+
[SmallestMaxSize(VIDEO_MODEL_MIN_SIZE), CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
|
| 200 |
+
normalize, ToTensor()])
|
| 201 |
+
|
| 202 |
+
state = torch.load(weights_path, map_location=lambda storage, loc: storage)
|
| 203 |
+
state = {key: value.float() for key, value in state.items()}
|
| 204 |
+
self.model.load_state_dict(state)
|
| 205 |
+
|
| 206 |
+
def classify(self, track_sequences):
|
| 207 |
+
track_sequences = [torch.stack([self.transform(image=face)['image'] for face in sequence]) for sequence in
|
| 208 |
+
track_sequences]
|
| 209 |
+
track_sequences = torch.cat(track_sequences).cuda()
|
| 210 |
+
with torch.no_grad():
|
| 211 |
+
track_probs = torch.sigmoid(self.model(track_sequences)).flatten().cpu().numpy()
|
| 212 |
+
|
| 213 |
+
return track_probs
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class TrackFacesClassifier(object):
|
| 217 |
+
def __init__(self, first_weights_path, second_weights_path):
|
| 218 |
+
first_model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 219 |
+
self.first_model = first_model.cuda().eval()
|
| 220 |
+
second_model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 221 |
+
self.second_model = second_model.cuda().eval()
|
| 222 |
+
|
| 223 |
+
first_normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 224 |
+
self.first_transform = Compose(
|
| 225 |
+
[SmallestMaxSize(VIDEO_MODEL_CROP_WIDTH), PadIfNeeded(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
|
| 226 |
+
CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH), first_normalize, ToTensor()])
|
| 227 |
+
|
| 228 |
+
second_normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 229 |
+
self.second_transform = Compose(
|
| 230 |
+
[SmallestMaxSize(VIDEO_MODEL_MIN_SIZE), CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
|
| 231 |
+
second_normalize, ToTensor()])
|
| 232 |
+
|
| 233 |
+
first_state = torch.load(first_weights_path, map_location=lambda storage, loc: storage)
|
| 234 |
+
first_state = {key: value.float() for key, value in first_state.items()}
|
| 235 |
+
self.first_model.load_state_dict(first_state)
|
| 236 |
+
|
| 237 |
+
second_state = torch.load(second_weights_path, map_location=lambda storage, loc: storage)
|
| 238 |
+
second_state = {key: value.float() for key, value in second_state.items()}
|
| 239 |
+
self.second_model.load_state_dict(second_state)
|
| 240 |
+
|
| 241 |
+
def classify(self, track_faces):
|
| 242 |
+
first_track_faces = []
|
| 243 |
+
second_track_faces = []
|
| 244 |
+
for i, face in enumerate(track_faces):
|
| 245 |
+
if i % 4 < 2:
|
| 246 |
+
first_track_faces.append(self.first_transform(image=face)['image'])
|
| 247 |
+
else:
|
| 248 |
+
second_track_faces.append(self.second_transform(image=face)['image'])
|
| 249 |
+
first_track_faces = torch.stack(first_track_faces).cuda()
|
| 250 |
+
second_track_faces = torch.stack(second_track_faces).cuda()
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
first_track_probs = torch.sigmoid(self.first_model(first_track_faces)).flatten().cpu().numpy()
|
| 253 |
+
second_track_probs = torch.sigmoid(self.second_model(second_track_faces)).flatten().cpu().numpy()
|
| 254 |
+
track_probs = np.concatenate((first_track_probs, second_track_probs))
|
| 255 |
+
|
| 256 |
+
return track_probs
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def extract_sequence(frames, start_idx, bbox, flip):
|
| 260 |
+
frame_height, frame_width, _ = frames[start_idx].shape
|
| 261 |
+
xmin, ymin, xmax, ymax = bbox
|
| 262 |
+
width = xmax - xmin
|
| 263 |
+
height = ymax - ymin
|
| 264 |
+
xcenter = xmin + width / 2
|
| 265 |
+
ycenter = ymin + height / 2
|
| 266 |
+
width = width * VIDEO_MODEL_BBOX_MULT
|
| 267 |
+
height = height * VIDEO_MODEL_BBOX_MULT
|
| 268 |
+
xmin = xcenter - width / 2
|
| 269 |
+
ymin = ycenter - height / 2
|
| 270 |
+
xmax = xmin + width
|
| 271 |
+
ymax = ymin + height
|
| 272 |
+
|
| 273 |
+
xmin = max(int(xmin), 0)
|
| 274 |
+
xmax = min(int(xmax), frame_width)
|
| 275 |
+
ymin = max(int(ymin), 0)
|
| 276 |
+
ymax = min(int(ymax), frame_height)
|
| 277 |
+
|
| 278 |
+
sequence = []
|
| 279 |
+
for i in range(VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH):
|
| 280 |
+
face = cv2.cvtColor(frames[start_idx + i][ymin:ymax, xmin:xmax], cv2.COLOR_BGR2RGB)
|
| 281 |
+
sequence.append(face)
|
| 282 |
+
|
| 283 |
+
if flip:
|
| 284 |
+
sequence = [face[:, ::-1] for face in sequence]
|
| 285 |
+
|
| 286 |
+
return sequence
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def extract_face(frame, bbox, flip):
|
| 290 |
+
frame_height, frame_width, _ = frame.shape
|
| 291 |
+
xmin, ymin, xmax, ymax = bbox
|
| 292 |
+
width = xmax - xmin
|
| 293 |
+
height = ymax - ymin
|
| 294 |
+
xcenter = xmin + width / 2
|
| 295 |
+
ycenter = ymin + height / 2
|
| 296 |
+
width = width * VIDEO_MODEL_BBOX_MULT
|
| 297 |
+
height = height * VIDEO_MODEL_BBOX_MULT
|
| 298 |
+
xmin = xcenter - width / 2
|
| 299 |
+
ymin = ycenter - height / 2
|
| 300 |
+
xmax = xmin + width
|
| 301 |
+
ymax = ymin + height
|
| 302 |
+
|
| 303 |
+
xmin = max(int(xmin), 0)
|
| 304 |
+
xmax = min(int(xmax), frame_width)
|
| 305 |
+
ymin = max(int(ymin), 0)
|
| 306 |
+
ymax = min(int(ymax), frame_height)
|
| 307 |
+
|
| 308 |
+
face = cv2.cvtColor(frame[ymin:ymax, xmin:xmax], cv2.COLOR_BGR2RGB)
|
| 309 |
+
if flip:
|
| 310 |
+
face = face[:, ::-1].copy()
|
| 311 |
+
|
| 312 |
+
return face
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def main():
|
| 316 |
+
with open('config.yaml', 'r') as f:
|
| 317 |
+
config = yaml.load(f)
|
| 318 |
+
|
| 319 |
+
detector = Detector(os.path.join(config['MODELS_PATH'], DETECTOR_WEIGHTS_PATH))
|
| 320 |
+
track_sequences_classifier = TrackSequencesClassifier(os.path.join(config['MODELS_PATH'], VIDEO_SEQUENCE_MODEL_WEIGHTS_PATH))
|
| 321 |
+
track_faces_classifier = TrackFacesClassifier(os.path.join(config['MODELS_PATH'], FIRST_VIDEO_FACE_MODEL_WEIGHTS_PATH),
|
| 322 |
+
os.path.join(config['MODELS_PATH'], SECOND_VIDEO_FACE_MODEL_WEIGHTS_PATH))
|
| 323 |
+
|
| 324 |
+
dataset = UnlabeledVideoDataset(os.path.join(config['DFDC_DATA_PATH'], 'test_videos'))
|
| 325 |
+
print('Total number of videos: {}'.format(len(dataset)))
|
| 326 |
+
|
| 327 |
+
loader = DataLoader(dataset, batch_size=VIDEO_BATCH_SIZE, shuffle=False, num_workers=VIDEO_NUM_WORKERS,
|
| 328 |
+
collate_fn=lambda X: X,
|
| 329 |
+
drop_last=False)
|
| 330 |
+
|
| 331 |
+
video_name_to_score = {}
|
| 332 |
+
|
| 333 |
+
for video_sample in loader:
|
| 334 |
+
frames = video_sample[0]['frames']
|
| 335 |
+
detector_frames = frames[::DETECTOR_STEP]
|
| 336 |
+
video_idx = video_sample[0]['index']
|
| 337 |
+
video_rel_path = dataset.content[video_idx]
|
| 338 |
+
video_name = os.path.basename(video_rel_path)
|
| 339 |
+
|
| 340 |
+
if len(frames) == 0:
|
| 341 |
+
video_name_to_score[video_name] = 0.5
|
| 342 |
+
continue
|
| 343 |
+
|
| 344 |
+
detections = []
|
| 345 |
+
for start in range(0, len(detector_frames), DETECTOR_BATCH_SIZE):
|
| 346 |
+
end = min(len(detector_frames), start + DETECTOR_BATCH_SIZE)
|
| 347 |
+
detections_batch = detector.detect(detector_frames[start:end])
|
| 348 |
+
for detections_per_frame in detections_batch:
|
| 349 |
+
detections.append({key: value.cpu().numpy() for key, value in detections_per_frame.items()})
|
| 350 |
+
|
| 351 |
+
tracks = get_tracks(detections)
|
| 352 |
+
if len(tracks) == 0:
|
| 353 |
+
video_name_to_score[video_name] = 0.5
|
| 354 |
+
continue
|
| 355 |
+
|
| 356 |
+
sequence_track_scores = []
|
| 357 |
+
for track in tracks:
|
| 358 |
+
track_sequences = []
|
| 359 |
+
for i, (start_idx, _) in enumerate(
|
| 360 |
+
track[:-VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH + 1:VIDEO_SEQUENCE_MODEL_TRACK_STEP]):
|
| 361 |
+
assert start_idx >= 0 and start_idx + VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH <= len(frames)
|
| 362 |
+
_, bbox = track[i * VIDEO_SEQUENCE_MODEL_TRACK_STEP + VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH // 2]
|
| 363 |
+
track_sequences.append(extract_sequence(frames, start_idx, bbox, i % 2 == 0))
|
| 364 |
+
sequence_track_scores.append(track_sequences_classifier.classify(track_sequences))
|
| 365 |
+
|
| 366 |
+
face_track_scores = []
|
| 367 |
+
for track in tracks:
|
| 368 |
+
track_faces = []
|
| 369 |
+
for i, (frame_idx, bbox) in enumerate(track[::VIDEO_FACE_MODEL_TRACK_STEP]):
|
| 370 |
+
face = extract_face(frames[frame_idx], bbox, i % 2 == 0)
|
| 371 |
+
track_faces.append(face)
|
| 372 |
+
face_track_scores.append(track_faces_classifier.classify(track_faces))
|
| 373 |
+
|
| 374 |
+
sequence_track_scores = np.concatenate(sequence_track_scores)
|
| 375 |
+
face_track_scores = np.concatenate(face_track_scores)
|
| 376 |
+
track_probs = np.concatenate((sequence_track_scores, face_track_scores))
|
| 377 |
+
|
| 378 |
+
delta = track_probs - 0.5
|
| 379 |
+
sign = np.sign(delta)
|
| 380 |
+
pos_delta = delta > 0
|
| 381 |
+
neg_delta = delta < 0
|
| 382 |
+
track_probs[pos_delta] = np.clip(0.5 + sign[pos_delta] * np.power(abs(delta[pos_delta]), 0.65), 0.01, 0.99)
|
| 383 |
+
track_probs[neg_delta] = np.clip(0.5 + sign[neg_delta] * np.power(abs(delta[neg_delta]), 0.65), 0.01, 0.99)
|
| 384 |
+
weights = np.power(abs(delta), 1.0) + 1e-4
|
| 385 |
+
video_score = float((track_probs * weights).sum() / weights.sum())
|
| 386 |
+
|
| 387 |
+
video_name_to_score[video_name] = video_score
|
| 388 |
+
print('NUM DETECTION FRAMES: {}, VIDEO SCORE: {}. {}'.format(len(detections), video_name_to_score[video_name],
|
| 389 |
+
video_rel_path))
|
| 390 |
+
|
| 391 |
+
os.makedirs(os.path.dirname(config['SUBMISSION_PATH']), exist_ok=True)
|
| 392 |
+
with open(config['SUBMISSION_PATH'], 'w') as f:
|
| 393 |
+
f.write('filename,label\n')
|
| 394 |
+
for video_name in sorted(video_name_to_score):
|
| 395 |
+
score = video_name_to_score[video_name]
|
| 396 |
+
f.write('{},{}\n'.format(video_name, score))
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
main()
|
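The final aggregation in `main` re-calibrates per-crop probabilities and then takes a confidence-weighted mean. A self-contained numpy sketch of just that step; the input probabilities are invented:

```python
# Sketch of the aggregation step: each probability is pushed away from 0.5
# via |p - 0.5| ** 0.65 (clipped to [0.01, 0.99]), and the video score is a
# mean weighted by distance from 0.5, so confident crops dominate.
import numpy as np

track_probs = np.array([0.50, 0.55, 0.80, 0.97, 0.10], dtype=np.float32)

delta = track_probs - 0.5
sign = np.sign(delta)
nonzero = delta != 0
track_probs[nonzero] = np.clip(
    0.5 + sign[nonzero] * np.abs(delta[nonzero]) ** 0.65, 0.01, 0.99)
weights = np.abs(delta) + 1e-4
video_score = float((track_probs * weights).sum() / weights.sum())
print(track_probs.round(3), round(video_score, 3))
```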
tracker/__init__.py
ADDED
|
File without changes
|
tracker/iou_tracker.py
ADDED
|
@@ -0,0 +1,58 @@
|
| 1 |
+
# Source: https://github.com/bochinski/iou-tracker
|
| 2 |
+
|
| 3 |
+
from .utils import iou
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def track_iou(detections, sigma_l, sigma_h, sigma_iou, t_min):
|
| 7 |
+
"""
|
| 8 |
+
Simple IOU based tracker.
|
| 9 |
+
See "High-Speed Tracking-by-Detection Without Using Image Information by E. Bochinski, V. Eiselein, T. Sikora" for
|
| 10 |
+
more information.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
detections (list): list of detections per frame, usually generated by util.load_mot
|
| 14 |
+
sigma_l (float): low detection threshold.
|
| 15 |
+
sigma_h (float): high detection threshold.
|
| 16 |
+
sigma_iou (float): IOU threshold.
|
| 17 |
+
t_min (float): minimum track length in frames.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
list: list of tracks.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
tracks_active = []
|
| 24 |
+
tracks_finished = []
|
| 25 |
+
|
| 26 |
+
for frame_num, detections_frame in enumerate(detections, start=1):
|
| 27 |
+
# apply low threshold to detections
|
| 28 |
+
dets = [det for det in detections_frame if det['score'] >= sigma_l]
|
| 29 |
+
|
| 30 |
+
updated_tracks = []
|
| 31 |
+
for track in tracks_active:
|
| 32 |
+
if len(dets) > 0:
|
| 33 |
+
# get det with highest iou
|
| 34 |
+
best_match = max(dets, key=lambda x: iou(track['bboxes'][-1], x['bbox']))
|
| 35 |
+
if iou(track['bboxes'][-1], best_match['bbox']) >= sigma_iou:
|
| 36 |
+
track['bboxes'].append(best_match['bbox'])
|
| 37 |
+
track['max_score'] = max(track['max_score'], best_match['score'])
|
| 38 |
+
|
| 39 |
+
updated_tracks.append(track)
|
| 40 |
+
|
| 41 |
+
# remove from best matching detection from detections
|
| 42 |
+
del dets[dets.index(best_match)]
|
| 43 |
+
|
| 44 |
+
# if track was not updated
|
| 45 |
+
if len(updated_tracks) == 0 or track is not updated_tracks[-1]:
|
| 46 |
+
# finish track when the conditions are met
|
| 47 |
+
if track['max_score'] >= sigma_h and len(track['bboxes']) >= t_min:
|
| 48 |
+
tracks_finished.append(track)
|
| 49 |
+
|
| 50 |
+
# create new tracks
|
| 51 |
+
new_tracks = [{'bboxes': [det['bbox']], 'max_score': det['score'], 'start_frame': frame_num} for det in dets]
|
| 52 |
+
tracks_active = updated_tracks + new_tracks
|
| 53 |
+
|
| 54 |
+
# finish all remaining active tracks
|
| 55 |
+
tracks_finished += [track for track in tracks_active
|
| 56 |
+
if track['max_score'] >= sigma_h and len(track['bboxes']) >= t_min]
|
| 57 |
+
|
| 58 |
+
return tracks_finished
|
tracker/utils.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def iou(bbox1, bbox2):
|
| 2 |
+
"""
|
| 3 |
+
Calculates the intersection-over-union of two bounding boxes.
|
| 4 |
+
|
| 5 |
+
Args:
|
| 6 |
+
bbox1 (numpy.array, list of floats): bounding box in format x1,y1,x2,y2.
|
| 7 |
+
bbox2 (numpy.array, list of floats): bounding box in format x1,y1,x2,y2.
|
| 8 |
+
|
| 9 |
+
Returns:
|
| 10 |
+
int: intersection-over-onion of bbox1, bbox2
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
bbox1 = [float(x) for x in bbox1]
|
| 14 |
+
bbox2 = [float(x) for x in bbox2]
|
| 15 |
+
|
| 16 |
+
(x0_1, y0_1, x1_1, y1_1) = bbox1
|
| 17 |
+
(x0_2, y0_2, x1_2, y1_2) = bbox2
|
| 18 |
+
|
| 19 |
+
# get the overlap rectangle
|
| 20 |
+
overlap_x0 = max(x0_1, x0_2)
|
| 21 |
+
overlap_y0 = max(y0_1, y0_2)
|
| 22 |
+
overlap_x1 = min(x1_1, x1_2)
|
| 23 |
+
overlap_y1 = min(y1_1, y1_2)
|
| 24 |
+
|
| 25 |
+
# check if there is an overlap
|
| 26 |
+
if overlap_x1 - overlap_x0 <= 0 or overlap_y1 - overlap_y0 <= 0:
|
| 27 |
+
return 0
|
| 28 |
+
|
| 29 |
+
# if yes, calculate the ratio of the overlap to each ROI size and the unified size
|
| 30 |
+
size_1 = (x1_1 - x0_1) * (y1_1 - y0_1)
|
| 31 |
+
size_2 = (x1_2 - x0_2) * (y1_2 - y0_2)
|
| 32 |
+
size_intersection = (overlap_x1 - overlap_x0) * (overlap_y1 - overlap_y0)
|
| 33 |
+
size_union = size_1 + size_2 - size_intersection
|
| 34 |
+
|
| 35 |
+
return size_intersection / size_union
|
train_b7_ns_aa_original_large_crop_100k.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import tqdm
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import distributions
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
+
|
| 16 |
+
import ffmpeg
|
| 17 |
+
|
| 18 |
+
from albumentations import ImageOnlyTransform
|
| 19 |
+
from albumentations import SmallestMaxSize, PadIfNeeded, HorizontalFlip, Normalize, Compose, RandomCrop
|
| 20 |
+
from albumentations.pytorch import ToTensor
|
| 21 |
+
from efficientnet_pytorch import EfficientNet
|
| 22 |
+
|
| 23 |
+
from timm.data.transforms_factory import transforms_imagenet_train
|
| 24 |
+
|
| 25 |
+
from datasets import TrackPairDataset
|
| 26 |
+
from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
|
| 27 |
+
from generate_track_pairs import TRACK_PAIRS_FILE_NAME
|
| 28 |
+
|
| 29 |
+
SEED = 30
|
| 30 |
+
BATCH_SIZE = 8
|
| 31 |
+
TRAIN_INDICES = [9, 13, 17, 21, 25, 29, 33, 37]
|
| 32 |
+
INITIAL_LR = 0.005
|
| 33 |
+
MOMENTUM = 0.9
|
| 34 |
+
WEIGHT_DECAY = 1e-4
|
| 35 |
+
NUM_WORKERS = 8
|
| 36 |
+
NUM_WARMUP_ITERATIONS = 100
|
| 37 |
+
SNAPSHOT_FREQUENCY = 1000
|
| 38 |
+
OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k'
|
| 39 |
+
SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
|
| 40 |
+
MAX_ITERS = 100000
|
| 41 |
+
|
| 42 |
+
FPS_RANGE = (15, 30)
|
| 43 |
+
SCALE_RANGE = (0.25, 1)
|
| 44 |
+
CRF_RANGE = (17, 40)
|
| 45 |
+
TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
|
| 46 |
+
|
| 47 |
+
CROP_HEIGHT = 224
|
| 48 |
+
CROP_WIDTH = 192
|
| 49 |
+
|
| 50 |
+
PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
|
| 51 |
+
SNAPSHOTS_ROOT = 'snapshots'
|
| 52 |
+
LOGS_ROOT = 'logs'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TrackTransform(object):
|
| 56 |
+
def __init__(self, fps_range, scale_range, crf_range, tune_values):
|
| 57 |
+
self.fps_range = fps_range
|
| 58 |
+
self.scale_range = scale_range
|
| 59 |
+
self.crf_range = crf_range
|
| 60 |
+
self.tune_values = tune_values
|
| 61 |
+
|
| 62 |
+
def get_params(self, src_fps, src_height, src_width):
|
| 63 |
+
if random.random() > 0.5:
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
dst_fps = src_fps
|
| 67 |
+
if random.random() > 0.5:
|
| 68 |
+
dst_fps = random.randrange(*self.fps_range)
|
| 69 |
+
|
| 70 |
+
scale = 1.0
|
| 71 |
+
if random.random() > 0.5:
|
| 72 |
+
scale = random.uniform(*self.scale_range)
|
| 73 |
+
|
| 74 |
+
dst_height = round(scale * src_height) // 2 * 2
|
| 75 |
+
dst_width = round(scale * src_width) // 2 * 2
|
| 76 |
+
|
| 77 |
+
crf = random.randrange(*self.crf_range)
|
| 78 |
+
tune = random.choice(self.tune_values)
|
| 79 |
+
|
| 80 |
+
return dst_fps, dst_height, dst_width, crf, tune
|
| 81 |
+
|
| 82 |
+
def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
|
| 83 |
+
out, err = (
|
| 84 |
+
ffmpeg
|
| 85 |
+
.input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
|
| 86 |
+
.filter('fps', fps=dst_fps)
|
| 87 |
+
.filter('scale', dst_width, dst_height)
|
| 88 |
+
.output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
|
| 89 |
+
.run(capture_stdout=True, quiet=True)
|
| 90 |
+
)
|
| 91 |
+
out, err = (
|
| 92 |
+
ffmpeg
|
| 93 |
+
.input('pipe:', format='h264')
|
| 94 |
+
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
|
| 95 |
+
.run(capture_stdout=True, input=out, quiet=True)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
|
| 99 |
+
|
| 100 |
+
return imgs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class VisionTransform(ImageOnlyTransform):
|
| 104 |
+
def __init__(
|
| 105 |
+
self, transform, is_tensor=True, always_apply=False, p=1.0
|
| 106 |
+
):
|
| 107 |
+
super(VisionTransform, self).__init__(always_apply, p)
|
| 108 |
+
self.transform = transform
|
| 109 |
+
self.is_tensor = is_tensor
|
| 110 |
+
|
| 111 |
+
def apply(self, image, **params):
|
| 112 |
+
if self.is_tensor:
|
| 113 |
+
return self.transform(image)
|
| 114 |
+
else:
|
| 115 |
+
return np.array(self.transform(Image.fromarray(image)))
|
| 116 |
+
|
| 117 |
+
def get_transform_init_args_names(self):
|
| 118 |
+
return ("transform")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def set_global_seed(seed):
|
| 122 |
+
torch.manual_seed(seed)
|
| 123 |
+
if torch.cuda.is_available():
|
| 124 |
+
torch.cuda.manual_seed_all(seed)
|
| 125 |
+
random.seed(seed)
|
| 126 |
+
np.random.seed(seed)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def prepare_cudnn(deterministic=None, benchmark=None):
|
| 130 |
+
# https://pytorch.org/docs/stable/notes/randomness.html#cudnn
|
| 131 |
+
if deterministic is None:
|
| 132 |
+
deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
|
| 133 |
+
torch.backends.cudnn.deterministic = deterministic
|
| 134 |
+
|
| 135 |
+
# https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
|
| 136 |
+
if benchmark is None:
|
| 137 |
+
benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
|
| 138 |
+
torch.backends.cudnn.benchmark = benchmark
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main():
|
| 142 |
+
with open('config.yaml', 'r') as f:
|
| 143 |
+
config = yaml.load(f)
|
| 144 |
+
|
| 145 |
+
set_global_seed(SEED)
|
| 146 |
+
prepare_cudnn(deterministic=True, benchmark=True)
|
| 147 |
+
|
| 148 |
+
model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 149 |
+
state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
|
| 150 |
+
state.pop('_fc.weight')
|
| 151 |
+
state.pop('_fc.bias')
|
| 152 |
+
res = model.load_state_dict(state, strict=False)
|
| 153 |
+
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
| 154 |
+
model = model.cuda()
|
| 155 |
+
|
| 156 |
+
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 157 |
+
_, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
|
| 158 |
+
separate=True)
|
| 159 |
+
|
| 160 |
+
train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
|
| 161 |
+
os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
|
| 162 |
+
TRAIN_INDICES,
|
| 163 |
+
track_length=TRACK_LENGTH,
|
| 164 |
+
track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
|
| 165 |
+
image_transform=Compose([
|
| 166 |
+
SmallestMaxSize(CROP_WIDTH),
|
| 167 |
+
PadIfNeeded(CROP_HEIGHT, CROP_WIDTH),
|
| 168 |
+
HorizontalFlip(),
|
| 169 |
+
RandomCrop(CROP_HEIGHT, CROP_WIDTH),
|
| 170 |
+
VisionTransform(rand_augment, is_tensor=False, p=0.5),
|
| 171 |
+
normalize,
|
| 172 |
+
ToTensor()
|
| 173 |
+
]), sequence_mode=False)
|
| 174 |
+
|
| 175 |
+
print('Train dataset size: {}.'.format(len(train_dataset)))
|
| 176 |
+
|
| 177 |
+
warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
|
| 178 |
+
weight_decay=WEIGHT_DECAY, nesterov=True)
|
| 179 |
+
|
| 180 |
+
full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
|
| 181 |
+
nesterov=True)
|
| 182 |
+
full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
|
| 183 |
+
lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
|
| 184 |
+
|
| 185 |
+
snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
|
| 186 |
+
os.makedirs(snapshots_root)
|
| 187 |
+
log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
|
| 188 |
+
os.makedirs(log_root)
|
| 189 |
+
|
| 190 |
+
writer = SummaryWriter(log_root)
|
| 191 |
+
|
| 192 |
+
iteration = 0
|
| 193 |
+
if iteration < NUM_WARMUP_ITERATIONS:
|
| 194 |
+
print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
|
| 195 |
+
model.eval()
|
| 196 |
+
model._fc.train()
|
| 197 |
+
for param in model.parameters():
|
| 198 |
+
param.requires_grad = False
|
| 199 |
+
for param in model._fc.parameters():
|
| 200 |
+
param.requires_grad = True
|
| 201 |
+
optimizer = warmup_optimizer
|
| 202 |
+
else:
|
| 203 |
+
print('Start without warmup iterations')
|
| 204 |
+
model.train()
|
| 205 |
+
optimizer = full_optimizer
|
| 206 |
+
|
| 207 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 208 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 209 |
+
|
| 210 |
+
epoch = 0
|
| 211 |
+
fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
|
| 212 |
+
while True:
|
| 213 |
+
epoch += 1
|
| 214 |
+
print('Epoch {} is in progress'.format(epoch))
|
| 215 |
+
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
|
| 216 |
+
for samples in tqdm.tqdm(loader):
|
| 217 |
+
iteration += 1
|
| 218 |
+
fake_input_tensor = torch.cat(samples['fake']).cuda()
|
| 219 |
+
real_input_tensor = torch.cat(samples['real']).cuda()
|
| 220 |
+
target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
|
| 221 |
+
fake_weight = target_fake_prob.view(-1, 1, 1, 1)
|
| 222 |
+
|
| 223 |
+
input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
|
| 224 |
+
pred = model(input_tensor).flatten()
|
| 225 |
+
|
| 226 |
+
loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob)
|
| 227 |
+
|
| 228 |
+
optimizer.zero_grad()
|
| 229 |
+
loss.backward()
|
| 230 |
+
optimizer.step()
|
| 231 |
+
if iteration > NUM_WARMUP_ITERATIONS:
|
| 232 |
+
full_lr_scheduler.step()
|
| 233 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 234 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 235 |
+
|
| 236 |
+
writer.add_scalar('train/loss', loss.item(), iteration)
|
| 237 |
+
|
| 238 |
+
if iteration == NUM_WARMUP_ITERATIONS:
|
| 239 |
+
print('Stop warmup iterations')
|
| 240 |
+
model.train()
|
| 241 |
+
for param in model.parameters():
|
| 242 |
+
param.requires_grad = True
|
| 243 |
+
optimizer = full_optimizer
|
| 244 |
+
|
| 245 |
+
if iteration % SNAPSHOT_FREQUENCY == 0:
|
| 246 |
+
snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
|
| 247 |
+
snapshot_path = os.path.join(snapshots_root, snapshot_name)
|
| 248 |
+
print('Saving snapshot to {}'.format(snapshot_path))
|
| 249 |
+
torch.save(model.state_dict(), snapshot_path)
|
| 250 |
+
|
| 251 |
+
if iteration >= MAX_ITERS:
|
| 252 |
+
print('Stop training due to maximum iteration exceeded')
|
| 253 |
+
return
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == '__main__':
|
| 257 |
+
main()
|
train_b7_ns_aa_original_re_100k.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import tqdm
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import distributions
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
+
|
| 16 |
+
import ffmpeg
|
| 17 |
+
|
| 18 |
+
from albumentations import ImageOnlyTransform
|
| 19 |
+
from albumentations import SmallestMaxSize, HorizontalFlip, Normalize, Compose, RandomCrop
|
| 20 |
+
from albumentations.pytorch import ToTensor
|
| 21 |
+
from efficientnet_pytorch import EfficientNet
|
| 22 |
+
|
| 23 |
+
from timm.data.transforms_factory import transforms_imagenet_train
|
| 24 |
+
from timm.data.random_erasing import RandomErasing
|
| 25 |
+
|
| 26 |
+
from datasets import TrackPairDataset
|
| 27 |
+
from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
|
| 28 |
+
from generate_track_pairs import TRACK_PAIRS_FILE_NAME
|
| 29 |
+
|
| 30 |
+
SEED = 10
|
| 31 |
+
BATCH_SIZE = 8
|
| 32 |
+
TRAIN_INDICES = [9, 13, 17, 21, 25, 29, 33, 37]
|
| 33 |
+
INITIAL_LR = 0.005
|
| 34 |
+
MOMENTUM = 0.9
|
| 35 |
+
WEIGHT_DECAY = 1e-4
|
| 36 |
+
NUM_WORKERS = 8
|
| 37 |
+
NUM_WARMUP_ITERATIONS = 100
|
| 38 |
+
SNAPSHOT_FREQUENCY = 1000
|
| 39 |
+
OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k'
|
| 40 |
+
SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
|
| 41 |
+
MAX_ITERS = 100000
|
| 42 |
+
|
| 43 |
+
FPS_RANGE = (15, 30)
|
| 44 |
+
SCALE_RANGE = (0.25, 1)
|
| 45 |
+
CRF_RANGE = (17, 40)
|
| 46 |
+
TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
|
| 47 |
+
|
| 48 |
+
RE_PROB = 0.2
|
| 49 |
+
RE_MODE = 'pixel'
|
| 50 |
+
RE_COUNT = 1
|
| 51 |
+
RE_NUM_SPLITS = 0
|
| 52 |
+
|
| 53 |
+
MIN_SIZE = 224
|
| 54 |
+
CROP_HEIGHT = 224
|
| 55 |
+
CROP_WIDTH = 192
|
| 56 |
+
|
| 57 |
+
PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
|
| 58 |
+
SNAPSHOTS_ROOT = 'snapshots'
|
| 59 |
+
LOGS_ROOT = 'logs'
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TrackTransform(object):
|
| 63 |
+
def __init__(self, fps_range, scale_range, crf_range, tune_values):
|
| 64 |
+
self.fps_range = fps_range
|
| 65 |
+
self.scale_range = scale_range
|
| 66 |
+
self.crf_range = crf_range
|
| 67 |
+
self.tune_values = tune_values
|
| 68 |
+
|
| 69 |
+
def get_params(self, src_fps, src_height, src_width):
|
| 70 |
+
if random.random() > 0.5:
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
dst_fps = src_fps
|
| 74 |
+
if random.random() > 0.5:
|
| 75 |
+
dst_fps = random.randrange(*self.fps_range)
|
| 76 |
+
|
| 77 |
+
scale = 1.0
|
| 78 |
+
if random.random() > 0.5:
|
| 79 |
+
scale = random.uniform(*self.scale_range)
|
| 80 |
+
|
| 81 |
+
dst_height = round(scale * src_height) // 2 * 2
|
| 82 |
+
dst_width = round(scale * src_width) // 2 * 2
|
| 83 |
+
|
| 84 |
+
crf = random.randrange(*self.crf_range)
|
| 85 |
+
tune = random.choice(self.tune_values)
|
| 86 |
+
|
| 87 |
+
return dst_fps, dst_height, dst_width, crf, tune
|
| 88 |
+
|
| 89 |
+
def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
|
| 90 |
+
out, err = (
|
| 91 |
+
ffmpeg
|
| 92 |
+
.input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
|
| 93 |
+
.filter('fps', fps=dst_fps)
|
| 94 |
+
.filter('scale', dst_width, dst_height)
|
| 95 |
+
.output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
|
| 96 |
+
.run(capture_stdout=True, quiet=True)
|
| 97 |
+
)
|
| 98 |
+
out, err = (
|
| 99 |
+
ffmpeg
|
| 100 |
+
.input('pipe:', format='h264')
|
| 101 |
+
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
|
| 102 |
+
.run(capture_stdout=True, input=out, quiet=True)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
|
| 106 |
+
|
| 107 |
+
return imgs
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class VisionTransform(ImageOnlyTransform):
|
| 111 |
+
def __init__(
|
| 112 |
+
self, transform, is_tensor=True, always_apply=False, p=1.0
|
| 113 |
+
):
|
| 114 |
+
super(VisionTransform, self).__init__(always_apply, p)
|
| 115 |
+
self.transform = transform
|
| 116 |
+
self.is_tensor = is_tensor
|
| 117 |
+
|
| 118 |
+
def apply(self, image, **params):
|
| 119 |
+
if self.is_tensor:
|
| 120 |
+
return self.transform(image)
|
| 121 |
+
else:
|
| 122 |
+
return np.array(self.transform(Image.fromarray(image)))
|
| 123 |
+
|
| 124 |
+
def get_transform_init_args_names(self):
|
| 125 |
+
return ("transform")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def set_global_seed(seed):
|
| 129 |
+
torch.manual_seed(seed)
|
| 130 |
+
if torch.cuda.is_available():
|
| 131 |
+
torch.cuda.manual_seed_all(seed)
|
| 132 |
+
random.seed(seed)
|
| 133 |
+
np.random.seed(seed)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def prepare_cudnn(deterministic=None, benchmark=None):
|
| 137 |
+
# https://pytorch.org/docs/stable/notes/randomness.html#cudnn
|
| 138 |
+
if deterministic is None:
|
| 139 |
+
deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
|
| 140 |
+
torch.backends.cudnn.deterministic = deterministic
|
| 141 |
+
|
| 142 |
+
# https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
|
| 143 |
+
if benchmark is None:
|
| 144 |
+
benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
|
| 145 |
+
torch.backends.cudnn.benchmark = benchmark
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def main():
|
| 149 |
+
with open('config.yaml', 'r') as f:
|
| 150 |
+
config = yaml.load(f)
|
| 151 |
+
|
| 152 |
+
set_global_seed(SEED)
|
| 153 |
+
prepare_cudnn(deterministic=True, benchmark=True)
|
| 154 |
+
|
| 155 |
+
model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 156 |
+
state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
|
| 157 |
+
state.pop('_fc.weight')
|
| 158 |
+
state.pop('_fc.bias')
|
| 159 |
+
res = model.load_state_dict(state, strict=False)
|
| 160 |
+
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
| 161 |
+
model = model.cuda()
|
| 162 |
+
|
| 163 |
+
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 164 |
+
_, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
|
| 165 |
+
separate=True)
|
| 166 |
+
|
| 167 |
+
train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
|
| 168 |
+
os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
|
| 169 |
+
TRAIN_INDICES,
|
| 170 |
+
track_length=TRACK_LENGTH,
|
| 171 |
+
track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
|
| 172 |
+
image_transform=Compose([
|
| 173 |
+
SmallestMaxSize(MIN_SIZE),
|
| 174 |
+
HorizontalFlip(),
|
| 175 |
+
RandomCrop(CROP_HEIGHT, CROP_WIDTH),
|
| 176 |
+
VisionTransform(rand_augment, is_tensor=False, p=0.5),
|
| 177 |
+
normalize,
|
| 178 |
+
ToTensor(),
|
| 179 |
+
VisionTransform(
|
| 180 |
+
RandomErasing(probability=RE_PROB, mode=RE_MODE, max_count=RE_COUNT,
|
| 181 |
+
num_splits=RE_NUM_SPLITS, device='cpu'), is_tensor=True)
|
| 182 |
+
]), sequence_mode=False)
|
| 183 |
+
|
| 184 |
+
print('Train dataset size: {}.'.format(len(train_dataset)))
|
| 185 |
+
|
| 186 |
+
warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
|
| 187 |
+
weight_decay=WEIGHT_DECAY, nesterov=True)
|
| 188 |
+
|
| 189 |
+
full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
|
| 190 |
+
nesterov=True)
|
| 191 |
+
full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
|
| 192 |
+
lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
|
| 193 |
+
|
| 194 |
+
snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
|
| 195 |
+
os.makedirs(snapshots_root)
|
| 196 |
+
log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
|
| 197 |
+
os.makedirs(log_root)
|
| 198 |
+
|
| 199 |
+
writer = SummaryWriter(log_root)
|
| 200 |
+
|
| 201 |
+
iteration = 0
|
| 202 |
+
if iteration < NUM_WARMUP_ITERATIONS:
|
| 203 |
+
print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
|
| 204 |
+
model.eval()
|
| 205 |
+
model._fc.train()
|
| 206 |
+
for param in model.parameters():
|
| 207 |
+
param.requires_grad = False
|
| 208 |
+
for param in model._fc.parameters():
|
| 209 |
+
param.requires_grad = True
|
| 210 |
+
optimizer = warmup_optimizer
|
| 211 |
+
else:
|
| 212 |
+
print('Start without warmup iterations')
|
| 213 |
+
model.train()
|
| 214 |
+
optimizer = full_optimizer
|
| 215 |
+
|
| 216 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 217 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 218 |
+
|
| 219 |
+
epoch = 0
|
| 220 |
+
fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
|
| 221 |
+
while True:
|
| 222 |
+
epoch += 1
|
| 223 |
+
print('Epoch {} is in progress'.format(epoch))
|
| 224 |
+
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
|
| 225 |
+
for samples in tqdm.tqdm(loader):
|
| 226 |
+
iteration += 1
|
| 227 |
+
fake_input_tensor = torch.cat(samples['fake']).cuda()
|
| 228 |
+
real_input_tensor = torch.cat(samples['real']).cuda()
|
| 229 |
+
target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
|
| 230 |
+
fake_weight = target_fake_prob.view(-1, 1, 1, 1)
|
| 231 |
+
|
| 232 |
+
input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
|
| 233 |
+
pred = model(input_tensor).flatten()
|
| 234 |
+
|
| 235 |
+
loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob)
|
| 236 |
+
|
| 237 |
+
optimizer.zero_grad()
|
| 238 |
+
loss.backward()
|
| 239 |
+
optimizer.step()
|
| 240 |
+
if iteration > NUM_WARMUP_ITERATIONS:
|
| 241 |
+
full_lr_scheduler.step()
|
| 242 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 243 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 244 |
+
|
| 245 |
+
writer.add_scalar('train/loss', loss.item(), iteration)
|
| 246 |
+
|
| 247 |
+
if iteration == NUM_WARMUP_ITERATIONS:
|
| 248 |
+
print('Stop warmup iterations')
|
| 249 |
+
model.train()
|
| 250 |
+
for param in model.parameters():
|
| 251 |
+
param.requires_grad = True
|
| 252 |
+
optimizer = full_optimizer
|
| 253 |
+
|
| 254 |
+
if iteration % SNAPSHOT_FREQUENCY == 0:
|
| 255 |
+
snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
|
| 256 |
+
snapshot_path = os.path.join(snapshots_root, snapshot_name)
|
| 257 |
+
print('Saving snapshot to {}'.format(snapshot_path))
|
| 258 |
+
torch.save(model.state_dict(), snapshot_path)
|
| 259 |
+
|
| 260 |
+
if iteration >= MAX_ITERS:
|
| 261 |
+
print('Stop training due to maximum iteration exceeded')
|
| 262 |
+
return
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == '__main__':
|
| 266 |
+
main()
|
train_b7_ns_seq_aa_original_100k.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import tqdm
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch import distributions
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
|
| 17 |
+
import ffmpeg
|
| 18 |
+
|
| 19 |
+
from albumentations import ImageOnlyTransform
|
| 20 |
+
from albumentations import SmallestMaxSize, HorizontalFlip, Normalize, Compose, RandomCrop
|
| 21 |
+
from albumentations.pytorch import ToTensor
|
| 22 |
+
from efficientnet_pytorch import EfficientNet
|
| 23 |
+
from efficientnet_pytorch.model import MBConvBlock
|
| 24 |
+
|
| 25 |
+
from timm.data.transforms_factory import transforms_imagenet_train
|
| 26 |
+
|
| 27 |
+
from datasets import TrackPairDataset
|
| 28 |
+
from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
|
| 29 |
+
from generate_track_pairs import TRACK_PAIRS_FILE_NAME
|
| 30 |
+
|
| 31 |
+
SEED = 20
|
| 32 |
+
BATCH_SIZE = 8
|
| 33 |
+
TRAIN_INDICES = [19, 21, 23, 25, 27, 29, 31]
|
| 34 |
+
INITIAL_LR = 0.005
|
| 35 |
+
MOMENTUM = 0.9
|
| 36 |
+
WEIGHT_DECAY = 1e-4
|
| 37 |
+
NUM_WORKERS = 8
|
| 38 |
+
NUM_WARMUP_ITERATIONS = 100
|
| 39 |
+
SNAPSHOT_FREQUENCY = 1000
|
| 40 |
+
OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k'
|
| 41 |
+
SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
|
| 42 |
+
FINAL_SNAPSHOT_NAME = 'final.pth'
|
| 43 |
+
MAX_ITERS = 100000
|
| 44 |
+
|
| 45 |
+
FPS_RANGE = (15, 30)
|
| 46 |
+
SCALE_RANGE = (0.25, 1)
|
| 47 |
+
CRF_RANGE = (17, 40)
|
| 48 |
+
TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
|
| 49 |
+
|
| 50 |
+
MIN_SIZE = 224
|
| 51 |
+
CROP_HEIGHT = 224
|
| 52 |
+
CROP_WIDTH = 192
|
| 53 |
+
|
| 54 |
+
PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
|
| 55 |
+
SNAPSHOTS_ROOT = 'snapshots'
|
| 56 |
+
LOGS_ROOT = 'logs'
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SeqExpandConv(nn.Module):
|
| 60 |
+
def __init__(self, in_channels, out_channels, seq_length):
|
| 61 |
+
super(SeqExpandConv, self).__init__()
|
| 62 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
|
| 63 |
+
self.seq_length = seq_length
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
batch_size, in_channels, height, width = x.shape
|
| 67 |
+
x = x.view(batch_size // self.seq_length, self.seq_length, in_channels, height, width)
|
| 68 |
+
x = self.conv(x.transpose(1, 2).contiguous()).transpose(2, 1).contiguous()
|
| 69 |
+
x = x.flatten(0, 1)
|
| 70 |
+
return x
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TrackTransform(object):
|
| 74 |
+
def __init__(self, fps_range, scale_range, crf_range, tune_values):
|
| 75 |
+
self.fps_range = fps_range
|
| 76 |
+
self.scale_range = scale_range
|
| 77 |
+
self.crf_range = crf_range
|
| 78 |
+
self.tune_values = tune_values
|
| 79 |
+
|
| 80 |
+
def get_params(self, src_fps, src_height, src_width):
|
| 81 |
+
if random.random() > 0.5:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
dst_fps = src_fps
|
| 85 |
+
if random.random() > 0.5:
|
| 86 |
+
dst_fps = random.randrange(*self.fps_range)
|
| 87 |
+
|
| 88 |
+
scale = 1.0
|
| 89 |
+
if random.random() > 0.5:
|
| 90 |
+
scale = random.uniform(*self.scale_range)
|
| 91 |
+
|
| 92 |
+
dst_height = round(scale * src_height) // 2 * 2
|
| 93 |
+
dst_width = round(scale * src_width) // 2 * 2
|
| 94 |
+
|
| 95 |
+
crf = random.randrange(*self.crf_range)
|
| 96 |
+
tune = random.choice(self.tune_values)
|
| 97 |
+
|
| 98 |
+
return dst_fps, dst_height, dst_width, crf, tune
|
| 99 |
+
|
| 100 |
+
def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
|
| 101 |
+
out, err = (
|
| 102 |
+
ffmpeg
|
| 103 |
+
.input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
|
| 104 |
+
.filter('fps', fps=dst_fps)
|
| 105 |
+
.filter('scale', dst_width, dst_height)
|
| 106 |
+
.output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
|
| 107 |
+
.run(capture_stdout=True, quiet=True)
|
| 108 |
+
)
|
| 109 |
+
out, err = (
|
| 110 |
+
ffmpeg
|
| 111 |
+
.input('pipe:', format='h264')
|
| 112 |
+
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
|
| 113 |
+
.run(capture_stdout=True, input=out, quiet=True)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
|
| 117 |
+
|
| 118 |
+
return imgs
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class VisionTransform(ImageOnlyTransform):
|
| 122 |
+
def __init__(
|
| 123 |
+
self, transform, always_apply=False, p=1.0
|
| 124 |
+
):
|
| 125 |
+
super(VisionTransform, self).__init__(always_apply, p)
|
| 126 |
+
self.transform = transform
|
| 127 |
+
|
| 128 |
+
def apply(self, image, **params):
|
| 129 |
+
return np.array(self.transform(Image.fromarray(image)))
|
| 130 |
+
|
| 131 |
+
def get_transform_init_args_names(self):
|
| 132 |
+
return ("transform")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def set_global_seed(seed):
|
| 136 |
+
torch.manual_seed(seed)
|
| 137 |
+
if torch.cuda.is_available():
|
| 138 |
+
torch.cuda.manual_seed_all(seed)
|
| 139 |
+
random.seed(seed)
|
| 140 |
+
np.random.seed(seed)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def prepare_cudnn(deterministic=None, benchmark=None):
|
| 144 |
+
# https://pytorch.org/docs/stable/notes/randomness.html#cudnn
|
| 145 |
+
if deterministic is None:
|
| 146 |
+
deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
|
| 147 |
+
torch.backends.cudnn.deterministic = deterministic
|
| 148 |
+
|
| 149 |
+
# https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
|
| 150 |
+
if benchmark is None:
|
| 151 |
+
benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
|
| 152 |
+
torch.backends.cudnn.benchmark = benchmark
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def main():
|
| 156 |
+
with open('config.yaml', 'r') as f:
|
| 157 |
+
config = yaml.load(f)
|
| 158 |
+
|
| 159 |
+
set_global_seed(SEED)
|
| 160 |
+
prepare_cudnn(deterministic=True, benchmark=True)
|
| 161 |
+
|
| 162 |
+
model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
|
| 163 |
+
state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
|
| 164 |
+
state.pop('_fc.weight')
|
| 165 |
+
state.pop('_fc.bias')
|
| 166 |
+
res = model.load_state_dict(state, strict=False)
|
| 167 |
+
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
|
| 168 |
+
|
| 169 |
+
for module in model.modules():
|
| 170 |
+
if isinstance(module, MBConvBlock):
|
| 171 |
+
if module._block_args.expand_ratio != 1:
|
| 172 |
+
expand_conv = module._expand_conv
|
| 173 |
+
seq_expand_conv = SeqExpandConv(expand_conv.in_channels, expand_conv.out_channels, len(TRAIN_INDICES))
|
| 174 |
+
seq_expand_conv.conv.weight.data[:, :, 0, :, :].copy_(expand_conv.weight.data / 3)
|
| 175 |
+
seq_expand_conv.conv.weight.data[:, :, 1, :, :].copy_(expand_conv.weight.data / 3)
|
| 176 |
+
seq_expand_conv.conv.weight.data[:, :, 2, :, :].copy_(expand_conv.weight.data / 3)
|
| 177 |
+
module._expand_conv = seq_expand_conv
|
| 178 |
+
|
| 179 |
+
model = model.cuda()
|
| 180 |
+
|
| 181 |
+
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 182 |
+
_, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
|
| 183 |
+
separate=True)
|
| 184 |
+
|
| 185 |
+
train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
|
| 186 |
+
os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
|
| 187 |
+
TRAIN_INDICES,
|
| 188 |
+
track_length=TRACK_LENGTH,
|
| 189 |
+
track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
|
| 190 |
+
image_transform=Compose([
|
| 191 |
+
SmallestMaxSize(MIN_SIZE),
|
| 192 |
+
HorizontalFlip(),
|
| 193 |
+
RandomCrop(CROP_HEIGHT, CROP_WIDTH),
|
| 194 |
+
VisionTransform(rand_augment, p=0.5),
|
| 195 |
+
normalize,
|
| 196 |
+
ToTensor()
|
| 197 |
+
]), sequence_mode=True)
|
| 198 |
+
|
| 199 |
+
print('Train dataset size: {}.'.format(len(train_dataset)))
|
| 200 |
+
|
| 201 |
+
warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
|
| 202 |
+
weight_decay=WEIGHT_DECAY, nesterov=True)
|
| 203 |
+
|
| 204 |
+
full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
|
| 205 |
+
nesterov=True)
|
| 206 |
+
full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
|
| 207 |
+
lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
|
| 208 |
+
|
| 209 |
+
snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
|
| 210 |
+
os.makedirs(snapshots_root)
|
| 211 |
+
log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
|
| 212 |
+
os.makedirs(log_root)
|
| 213 |
+
|
| 214 |
+
writer = SummaryWriter(log_root)
|
| 215 |
+
|
| 216 |
+
iteration = 0
|
| 217 |
+
if iteration < NUM_WARMUP_ITERATIONS:
|
| 218 |
+
print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
|
| 219 |
+
model.eval()
|
| 220 |
+
model._fc.train()
|
| 221 |
+
for param in model.parameters():
|
| 222 |
+
param.requires_grad = False
|
| 223 |
+
for param in model._fc.parameters():
|
| 224 |
+
param.requires_grad = True
|
| 225 |
+
optimizer = warmup_optimizer
|
| 226 |
+
else:
|
| 227 |
+
print('Start without warmup iterations')
|
| 228 |
+
model.train()
|
| 229 |
+
optimizer = full_optimizer
|
| 230 |
+
|
| 231 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 232 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 233 |
+
|
| 234 |
+
epoch = 0
|
| 235 |
+
fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
|
| 236 |
+
while True:
|
| 237 |
+
epoch += 1
|
| 238 |
+
print('Epoch {} is in progress'.format(epoch))
|
| 239 |
+
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
|
| 240 |
+
for samples in tqdm.tqdm(loader):
|
| 241 |
+
iteration += 1
|
| 242 |
+
fake_input_tensor = torch.stack(samples['fake']).transpose(0, 1).cuda()
|
| 243 |
+
real_input_tensor = torch.stack(samples['real']).transpose(0, 1).cuda()
|
| 244 |
+
target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
|
| 245 |
+
fake_weight = target_fake_prob.view(-1, 1, 1, 1, 1)
|
| 246 |
+
|
| 247 |
+
input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
|
| 248 |
+
pred = model(input_tensor.flatten(0, 1)).flatten()
|
| 249 |
+
|
| 250 |
+
loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob.repeat_interleave(len(TRAIN_INDICES)))
|
| 251 |
+
|
| 252 |
+
optimizer.zero_grad()
|
| 253 |
+
loss.backward()
|
| 254 |
+
optimizer.step()
|
| 255 |
+
if iteration > NUM_WARMUP_ITERATIONS:
|
| 256 |
+
full_lr_scheduler.step()
|
| 257 |
+
max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
|
| 258 |
+
writer.add_scalar('train/max_lr', max_lr, iteration)
|
| 259 |
+
|
| 260 |
+
writer.add_scalar('train/loss', loss.item(), iteration)
|
| 261 |
+
|
| 262 |
+
if iteration == NUM_WARMUP_ITERATIONS:
|
| 263 |
+
print('Stop warmup iterations')
|
| 264 |
+
model.train()
|
| 265 |
+
for param in model.parameters():
|
| 266 |
+
param.requires_grad = True
|
| 267 |
+
optimizer = full_optimizer
|
| 268 |
+
|
| 269 |
+
if iteration % SNAPSHOT_FREQUENCY == 0:
|
| 270 |
+
snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
|
| 271 |
+
snapshot_path = os.path.join(snapshots_root, snapshot_name)
|
| 272 |
+
print('Saving snapshot to {}'.format(snapshot_path))
|
| 273 |
+
torch.save(model.state_dict(), snapshot_path)
|
| 274 |
+
|
| 275 |
+
if iteration >= MAX_ITERS:
|
| 276 |
+
print('Stop training due to maximum iteration exceeded')
|
| 277 |
+
return
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == '__main__':
|
| 281 |
+
main()
|