Aryan6192 committed (verified)
Commit 79cf6ef · 1 Parent(s): 814d96f
Files changed (44)
  1. .gitattributes +2 -0
  2. .gitignore +11 -0
  3. Dockerfile +55 -0
  4. LICENSE +201 -0
  5. README.md +207 -10
  6. config.yaml +4 -0
  7. convert_tf_to_pt.sh +6 -0
  8. copy_weights.py +36 -0
  9. datasets.py +162 -0
  10. detect_faces_on_videos.py +82 -0
  11. dsfacedetector/__init__.py +0 -0
  12. dsfacedetector/data/__init__.py +0 -0
  13. dsfacedetector/data/config.py +57 -0
  14. dsfacedetector/face_ssd_infer.py +156 -0
  15. dsfacedetector/layers/__init__.py +3 -0
  16. dsfacedetector/layers/detection.py +157 -0
  17. dsfacedetector/layers/modules.py +98 -0
  18. dsfacedetector/layers/prior_box.py +133 -0
  19. dsfacedetector/utils.py +101 -0
  20. external_data/convert_tf_to_pt.py +174 -0
  21. external_data/original_tf/__init__.py +0 -0
  22. external_data/original_tf/efficientnet_builder.py +329 -0
  23. external_data/original_tf/efficientnet_model.py +713 -0
  24. external_data/original_tf/eval_ckpt_main.py +221 -0
  25. external_data/original_tf/preprocessing.py +241 -0
  26. external_data/original_tf/utils.py +405 -0
  27. extract_tracks_from_videos.py +105 -0
  28. generate_aligned_tracks.py +99 -0
  29. generate_track_pairs.py +70 -0
  30. generate_tracks.py +70 -0
  31. images/augmented_mixup.jpg +3 -0
  32. images/clip_example.jpg +0 -0
  33. images/first_and_second_model_inputs.jpg +0 -0
  34. images/mixup_example.jpg +3 -0
  35. images/pred_transform.jpg +0 -0
  36. images/third_model_input.jpg +0 -0
  37. models/.gitkeep +0 -0
  38. predict.py +399 -0
  39. tracker/__init__.py +0 -0
  40. tracker/iou_tracker.py +58 -0
  41. tracker/utils.py +35 -0
  42. train_b7_ns_aa_original_large_crop_100k.py +257 -0
  43. train_b7_ns_aa_original_re_100k.py +266 -0
  44. train_b7_ns_seq_aa_original_100k.py +281 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/augmented_mixup.jpg filter=lfs diff=lfs merge=lfs -text
+ images/mixup_example.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+ # PyCharm
+ .idea
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
Dockerfile ADDED
@@ -0,0 +1,55 @@
+ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
+
+ SHELL ["/bin/bash", "-c"]
+
+ RUN rm /etc/apt/sources.list.d/cuda.list \
+         /etc/apt/sources.list.d/nvidia-ml.list && \
+     apt-get update && \
+     DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+         software-properties-common \
+         wget \
+         git && \
+     add-apt-repository -y ppa:deadsnakes/ppa && \
+     apt-get update && \
+     DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+         python3.6 \
+         python3.6-dev && \
+     wget -O ~/get-pip.py \
+         https://bootstrap.pypa.io/get-pip.py && \
+     python3.6 ~/get-pip.py && \
+     pip3 --no-cache-dir install \
+         numpy==1.17.4 \
+         PyYAML==5.1.2 \
+         mkl==2019.0 \
+         mkl-include==2019.0 \
+         cmake==3.15.3 \
+         cffi==1.13.2 \
+         typing==3.7.4.1 \
+         six==1.13.0 \
+         Pillow==6.2.1 \
+         scipy==1.4.1 && \
+     cd /tmp && \
+     git clone https://github.com/pytorch/pytorch.git && \
+     cd pytorch && \
+     git checkout v1.3.0 && \
+     git submodule update --init --recursive && \
+     python3.6 setup.py install && \
+     cd /tmp && \
+     git clone https://github.com/pytorch/vision.git && \
+     cd vision && \
+     git checkout v0.4.1 && \
+     python3.6 setup.py install && \
+     DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+         ffmpeg && \
+     pip3 --no-cache-dir install \
+         opencv-python==4.1.2.30 \
+         albumentations==0.4.3 \
+         tqdm==4.39.0 \
+         timm==0.1.18 \
+         efficientnet-pytorch==0.6.3 \
+         ffmpeg-python==0.2.0 \
+         tensorflow==1.15.2 && \
+     cd / && \
+     apt-get clean && \
+     apt-get autoremove && \
+     rm -rf /var/lib/apt/lists/* /tmp/*
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 N-TECH.LAB LTD
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,10 +1,207 @@
- ---
- title: Deep
- emoji: 🔥
- colorFrom: green
- colorTo: pink
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Deepfake Detection Challenge
+ Solution for the [Deepfake Detection Challenge](https://www.kaggle.com/c/deepfake-detection-challenge).
+ Private LB score: **0.43452**
+ ## Solution description
+ ### Summary
+ Our solution consists of three EfficientNet-B7 models (we used the Noisy Student pre-trained weights). We did not use
+ external data, except for pre-trained weights. One model runs on frame sequences (a 3D convolution has been added to
+ each EfficientNet-B7 block). The other two models work frame-by-frame and differ in the size of the face crop and the
+ augmentations used during training. To tackle the overfitting problem, we used the mixup technique on aligned real-fake pairs. In
+ addition, we used the following augmentations: AutoAugment, Random Erasing, Random Crops, Random Flips, and various
+ video compression parameters. Video compression augmentation was done on-the-fly. To do this, short cropped tracks (50
+ frames each) were saved in PNG format, and at each training iteration they were loaded and re-encoded with random
+ parameters using ffmpeg. Due to the mixup, model predictions were “uncertain”, so at the inference stage, model
+ confidence was strengthened by a simple transformation. The final prediction was obtained by averaging the predictions
+ of the models with weights proportional to confidence. The total training and preprocessing time is approximately 5 days on
+ a DGX-1.
+ ### Key ingredients
+ #### Mixup on aligned real-fake pairs
+ One of the main difficulties of this competition is severe overfitting. Initially, all models overfitted in 2-3 epochs
+ (the validation loss started to increase). The idea that helped a lot with the overfitting is to train the model on
+ a mix of real and fake faces: for each fake face, we take the corresponding real face from the original video (with the
+ same box coordinates and the same frame number) and take a linear combination of them. In tensor terms, it is
+ ```python
+ input_tensor = (1.0 - target) * real_input_tensor + target * fake_input_tensor
+ ```
+ where target is drawn from a Beta distribution with parameters alpha=beta=0.5. With these parameters, there is a very
+ high probability of picking values close to 0 or 1 (a pure real or a pure fake face). You can see examples below:
+ ![mixup example](images/mixup_example.jpg "Mixup example")
+ Because the real and fake samples are aligned, the background remains almost unchanged in the interpolated samples,
+ which reduces overfitting and makes the model pay more attention to the face.
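As an illustration only (the helper name and the Beta sampling call below are ours, not lifted from the training scripts), the pairing step can be sketched in PyTorch as:

```python
import torch

def mixup_aligned_pair(real_input_tensor, fake_input_tensor, alpha=0.5):
    """Mix an aligned real/fake pair; the mixing coefficient doubles as the soft label.

    Both tensors are assumed to be crops of the same track with identical box
    coordinates and frame numbers, so only the manipulated face region differs.
    """
    # Beta(0.5, 0.5) concentrates mass near 0 and 1, so most mixed samples are
    # (almost) pure real or (almost) pure fake.
    target = torch.distributions.Beta(alpha, alpha).sample()
    input_tensor = (1.0 - target) * real_input_tensor + target * fake_input_tensor
    return input_tensor, target
```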
+ #### Video compression augmentation
+ In the paper \[1\] it was pointed out that augmentations close to the degradations seen in real-life video distributions
+ were applied to the test data. Specifically, these augmentations were (1) reducing the FPS of the video to 15; (2) reducing
+ the resolution of the video to 1/4 of its original size; and (3) reducing the overall encoding quality. In order to make
+ the model robust to various video compression parameters, we added augmentations with random video encoding parameters
+ to the training. It would be infeasible to apply such augmentations to the original videos on-the-fly during
+ training, so instead of the original videos, cropped (1.5x area around the face), short (50-frame) clips were used.
+ Each clip was saved as separate frames in PNG format. An example of a clip is given below:
+ ![clip example](images/clip_example.jpg "Clip example")
+ For on-the-fly augmentation, ffmpeg-python was used. At each iteration, the following parameters were randomly sampled
+ (see \[2\]); a sketch of this re-encoding step follows the list:
+ - FPS (15 to 30)
+ - scale (0.25 to 1.0)
+ - CRF (17 to 40)
+ - random tuning option
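A minimal sketch of this re-encoding step with ffmpeg-python, assuming a clip stored as numbered PNG frames in `track_dir` (the helper name, the exact filter chain, and the tune list are our assumptions; the parameter ranges mirror the list above):

```python
import random
import ffmpeg  # ffmpeg-python

def reencode_track(track_dir, out_path, src_fps=30):
    """Re-encode a clip of PNG frames with random compression parameters."""
    fps = random.randint(15, 30)
    scale = random.uniform(0.25, 1.0)
    crf = random.randint(17, 40)
    tune = random.choice(['film', 'animation', 'grain', 'fastdecode', 'zerolatency'])

    stream = ffmpeg.input(f'{track_dir}/%d.png', framerate=src_fps)
    stream = stream.filter('fps', fps=fps)                  # drop frames
    stream = stream.filter('scale', f'iw*{scale:.2f}', -2)  # downscale, keep aspect ratio
    stream = stream.output(out_path, vcodec='libx264', crf=crf, tune=tune)
    stream.overwrite_output().run(quiet=True)
```

In training, the re-encoded clip would then be decoded back into frames before being fed to the model.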
+ #### Model architecture
+ As a result of our experiments, we found that EfficientNet models work better than the others we checked (ResNet,
+ ResNeXt, SE-ResNeXt). The best model was EfficientNet-B7 with Noisy Student pre-trained weights \[3\]. The size of the
+ input image is 224x192 (most of the faces in the training dataset are smaller). The final ensemble consists of three
+ models, two of which are frame-by-frame, and the third works on sequences.
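For orientation, a frame-by-frame classifier along these lines can be sketched with the `efficientnet-pytorch` package pinned in the Dockerfile. This is not the exact model definition from the training scripts, and whether the converted checkpoint's key names match this class is an assumption (hence `strict=False`):

```python
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet  # efficientnet-pytorch==0.6.3

def build_frame_model(weights_path='external_data/noisy_student_efficientnet-b7.pth'):
    # EfficientNet-B7 backbone with a single-logit (real vs. fake) head.
    model = EfficientNet.from_name('efficientnet-b7')
    model._fc = nn.Linear(model._fc.in_features, 1)

    # Converted Noisy Student weights (see convert_tf_to_pt.sh); the key names may
    # not match this class exactly, hence strict=False in this sketch.
    state = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state, strict=False)
    return model

# Frames are resized to 224x192 before being fed to the model, e.g.:
# logits = build_frame_model()(torch.randn(1, 3, 224, 192))
```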
+ ##### Frame-by-frame models
+ The frame-by-frame models work quite well. They differ in the size of the area around the face and in the augmentations
+ used during training. Below are examples of input images for each of the models:
+ ![first and second model inputs](images/first_and_second_model_inputs.jpg "First and second model input examples")
+ ##### Sequence-based model
+ Temporal dependencies are likely to be useful for detecting fakes, so we added a 3D convolution to each block of the
+ EfficientNet model. This model worked slightly better than a similar frame-by-frame model. The length of the input
+ sequence is 7 frames. The step between frames is 1/15 of a second. An example of an input sequence is given below:
+ ![third model input](images/third_model_input.jpg "Third model input example")
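One way to picture "a 3D convolution added to each block" is a purely temporal convolution wrapped around each 2D block, as in the sketch below; the wrapper, kernel shape, and residual connection are our assumptions, not the exact architecture from `train_b7_ns_seq_aa_original_100k.py`:

```python
import torch
import torch.nn as nn

class TemporalBlockWrapper(nn.Module):
    """Wraps a 2D EfficientNet block and mixes information across the 7 input frames."""

    def __init__(self, block2d, channels, seq_len=7):
        super().__init__()
        self.block2d = block2d          # original 2D block
        self.seq_len = seq_len
        # Purely temporal 3D convolution: kernel 3 in time, 1x1 in space.
        self.conv3d = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))

    def forward(self, x):               # x: (batch * seq_len, C, H, W)
        x = self.block2d(x)
        bt, c, h, w = x.shape
        b = bt // self.seq_len
        x = x.view(b, self.seq_len, c, h, w).permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x = x + self.conv3d(x)          # residual temporal mixing
        x = x.permute(0, 2, 1, 3, 4).reshape(bt, c, h, w)
        return x
```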
+ #### Image augmentations
+ To improve model generalization, we used the following augmentations: AutoAugment \[4\], Random Erasing, Random Crops, and
+ Random Horizontal Flips. Since we used mixup, it was important to augment real-fake pairs in the same way (see the example below).
+ For the sequence-based model, it was also important to augment frames that belong to the same clip in the same way.
+ ![augmented mixup](images/augmented_mixup.jpg "Augmented mixup example")
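The `TrackPairDataset` in `datasets.py` (part of this commit) implements this by saving and restoring the `random` state around each augmentation call, so every frame of a pair draws identical augmentation parameters; in sketch form (the helper function is ours):

```python
import random

def augment_pair_identically(image_transform, real_frames, fake_frames):
    """Apply the same random augmentation to every frame of an aligned real/fake pair.

    Mirrors the random.getstate()/setstate() trick used by TrackPairDataset in
    datasets.py: restoring the RNG state before each call makes the albumentations
    transform sample the same parameters every time.
    """
    state = random.getstate()

    real_out, fake_out = [], []
    for img in real_frames:
        random.setstate(state)
        real_out.append(image_transform(image=img)['image'])
    for img in fake_frames:
        random.setstate(state)
        fake_out.append(image_transform(image=img)['image'])
    return real_out, fake_out
```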
+ #### Inference post-processing
+ Due to the mixup, the predictions of the models were uncertain, which is not optimal for the log loss. To increase
+ confidence, we applied the following transformation:
+ ![prediction transform](images/pred_transform.jpg "Prediction transformation")
+ Due to computational limitations, predictions are made on a subsample of frames. Half of the frames were horizontally
+ flipped. The prediction for the video is obtained by averaging all the predictions with weights proportional to the
+ confidence (the closer a prediction is to 0.5, the lower its weight). Such averaging works like attention, because the
+ model gives predictions close to 0.5 on poor-quality frames (profile faces, blur, etc.).
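A sketch of that confidence-weighted averaging, assuming a linear weight |p - 0.5| (the text only states that weights are proportional to confidence, and the sharpening transform itself is shown only in the figure above, so neither is reproduced exactly here; `eps` is our own guard against an all-0.5 edge case):

```python
import numpy as np

def aggregate_video_prediction(frame_probs, eps=1e-6):
    """Average per-frame fake probabilities, weighting confident frames more."""
    frame_probs = np.asarray(frame_probs, dtype=np.float64)
    # Weight is the distance from the "don't know" point 0.5.
    weights = np.abs(frame_probs - 0.5) + eps
    return float(np.sum(weights * frame_probs) / np.sum(weights))

# Example: two confident frames dominate one uncertain frame.
# aggregate_video_prediction([0.9, 0.85, 0.5]) ≈ 0.88
```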
+ #### References
+ \[1\] Brian Dolhansky, Russ Howes, Ben Pflaum, Nicole Baram, Cristian Canton Ferrer, “The Deepfake Detection Challenge
+ (DFDC) Preview Dataset”
+ \[2\] [https://trac.ffmpeg.org/wiki/Encode/H.264](https://trac.ffmpeg.org/wiki/Encode/H.264)
+ \[3\] Qizhe Xie, Minh-Thang Luong, Eduard Hovy, Quoc V. Le, “Self-training with Noisy Student improves ImageNet classification”
+ \[4\] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, Quoc V. Le, “AutoAugment: Learning Augmentation Policies from Data”
+ ## The hardware we used
+ - CPU: Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
+ - GPU: 8x NVIDIA Tesla V100 SXM2 32 GB
+ - RAM: 512 GB
+ - SSD: 6 TB
+ ## Prerequisites
+ ### Environment
+ Use Docker to get an environment close to the one used for training. Run the following commands to build the Docker image:
+ ```bash
+ cd path/to/solution
+ sudo docker build -t dfdc .
+ ```
+ ### Data
+ Download the [deepfake-detection-challenge-data](https://www.kaggle.com/c/deepfake-detection-challenge/data) and extract all files to `/path/to/dfdc-data`. This directory must have the following structure:
+ ```
+ dfdc-data
+ ├── dfdc_train_part_0
+ ├── dfdc_train_part_1
+ ├── dfdc_train_part_10
+ ├── dfdc_train_part_11
+ ├── dfdc_train_part_12
+ ├── dfdc_train_part_13
+ ├── dfdc_train_part_14
+ ├── dfdc_train_part_15
+ ├── dfdc_train_part_16
+ ├── dfdc_train_part_17
+ ├── dfdc_train_part_18
+ ├── dfdc_train_part_19
+ ├── dfdc_train_part_2
+ ├── dfdc_train_part_20
+ ├── dfdc_train_part_21
+ ├── dfdc_train_part_22
+ ├── dfdc_train_part_23
+ ├── dfdc_train_part_24
+ ├── dfdc_train_part_25
+ ├── dfdc_train_part_26
+ ├── dfdc_train_part_27
+ ├── dfdc_train_part_28
+ ├── dfdc_train_part_29
+ ├── dfdc_train_part_3
+ ├── dfdc_train_part_30
+ ├── dfdc_train_part_31
+ ├── dfdc_train_part_32
+ ├── dfdc_train_part_33
+ ├── dfdc_train_part_34
+ ├── dfdc_train_part_35
+ ├── dfdc_train_part_36
+ ├── dfdc_train_part_37
+ ├── dfdc_train_part_38
+ ├── dfdc_train_part_39
+ ├── dfdc_train_part_4
+ ├── dfdc_train_part_40
+ ├── dfdc_train_part_41
+ ├── dfdc_train_part_42
+ ├── dfdc_train_part_43
+ ├── dfdc_train_part_44
+ ├── dfdc_train_part_45
+ ├── dfdc_train_part_46
+ ├── dfdc_train_part_47
+ ├── dfdc_train_part_48
+ ├── dfdc_train_part_49
+ ├── dfdc_train_part_5
+ ├── dfdc_train_part_6
+ ├── dfdc_train_part_7
+ ├── dfdc_train_part_8
+ ├── dfdc_train_part_9
+ └── test_videos
+ ```
+
+ ### External data
+ According to the rules of the competition, external data is allowed. The solution does not use any other external data, except for the pre-trained models. Below is a table with information about these models.
+
+ | File Name | Source | Direct Link | Forum Post |
+ | --------- | ------ | ----------- | ---------- |
+ | WIDERFace_DSFD_RES152.pth | [github](https://github.com/Tencent/FaceDetection-DSFD/tree/31aa8bdeaf01a0c408adaf2709754a16b17aec79) | [google drive](https://drive.google.com/file/d/1WeXlNYsM6dMP3xQQELI-4gxhwKUQxc3-/view) | [link](https://www.kaggle.com/c/deepfake-detection-challenge/discussion/121203#761391) |
+ | noisy_student_efficientnet-b7.tar.gz | [github](https://github.com/tensorflow/tpu/tree/4719695c9128622fb26dedb19ea19bd9d1ee3177/models/official/efficientnet) | [link](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b7.tar.gz) | [link](https://www.kaggle.com/c/deepfake-detection-challenge/discussion/121203#748358) |
+
+ Download these files and copy them to the `external_data` folder.
+
+ ## How to train the model
+ Run the Docker container with the paths correctly mounted:
+ ```bash
+ sudo docker run --runtime=nvidia -i -t -d --rm --ipc=host -v /path/to/dfdc-data:/kaggle/input/deepfake-detection-challenge:ro -v /path/to/solution:/kaggle/solution --name dfdc dfdc
+ sudo docker exec -it dfdc /bin/bash
+ cd /kaggle/solution
+ ```
+ Convert the pre-trained model from TensorFlow to PyTorch:
+ ```bash
+ bash convert_tf_to_pt.sh
+ ```
+ Detect faces on the videos:
+ ```bash
+ python3.6 detect_faces_on_videos.py
+ ```
+ _Note: you can parallelize this operation using the `--part` and `--num_parts` arguments._
+ Generate tracks:
+ ```bash
+ python3.6 generate_tracks.py
+ ```
+ Generate aligned tracks:
+ ```bash
+ python3.6 generate_aligned_tracks.py
+ ```
+ Extract tracks from the videos:
+ ```bash
+ python3.6 extract_tracks_from_videos.py
+ ```
+ _Note: you can parallelize this operation using the `--part` and `--num_parts` arguments._
+ Generate track pairs:
+ ```bash
+ python3.6 generate_track_pairs.py
+ ```
+ Train the models:
+ ```bash
+ python3.6 train_b7_ns_aa_original_large_crop_100k.py
+ python3.6 train_b7_ns_aa_original_re_100k.py
+ python3.6 train_b7_ns_seq_aa_original_100k.py
+ ```
+ Copy the final weights and convert them to FP16:
+ ```bash
+ python3.6 copy_weights.py
+ ```
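The FP16 conversion is essentially a cast of every tensor in each checkpoint to half precision, as `copy_weights.py` (included later in this commit) does; in essence:

```python
import torch

def to_fp16_checkpoint(src_path, dst_path):
    # Load on CPU, cast every tensor to float16, and save the smaller checkpoint.
    state = torch.load(src_path, map_location='cpu')
    torch.save({k: v.half() for k, v in state.items()}, dst_path)
```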
+ ## Serialized copy of the trained model
+ You can download the final weights that were used in the competition (the output of the `copy_weights.py` script): [GoogleDrive](https://drive.google.com/file/d/1S-HeppZcbXDF0F-BO96zhqZyrRWOaan6/view?usp=sharing)
+ ## How to generate submission
+ Run the following command:
+ ```bash
+ python3.6 predict.py
+ ```
config.yaml ADDED
@@ -0,0 +1,4 @@
+ DFDC_DATA_PATH: "/kaggle/input/deepfake-detection-challenge"
+ ARTIFACTS_PATH: "/kaggle/solution/artifacts"
+ MODELS_PATH: "/kaggle/solution/models"
+ SUBMISSION_PATH: "/kaggle/solution/output/submission.csv"
convert_tf_to_pt.sh ADDED
@@ -0,0 +1,6 @@
+ #!/bin/bash
+ cd external_data && \
+ tar -xzf noisy_student_efficientnet-b7.tar.gz && \
+ python3.6 convert_tf_to_pt.py --model_name efficientnet-b7 --tf_checkpoint noisy-student-efficientnet-b7 --output_file noisy_student_efficientnet-b7.pth && \
+ rm -rf noisy-student-efficientnet-b7 tmp && \
+ cd ..
copy_weights.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ import yaml
+
+ import torch
+
+ WEIGHTS_MAPPING = {
+     'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k/snapshot_100000.pth': 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth',
+     'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_re_100k/snapshot_100000.pth': 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth',
+     'snapshots/efficientnet-b7_ns_seq_aa-original-mstd0.5_100k/snapshot_100000.pth': 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth'
+ }
+
+ SRC_DETECTOR_WEIGHTS = 'external_data/WIDERFace_DSFD_RES152.pth'
+ DST_DETECTOR_WEIGHTS = 'WIDERFace_DSFD_RES152.fp16.pth'
+
+
+ def copy_weights(src_path, dst_path):
+     state = torch.load(src_path, map_location=lambda storage, loc: storage)
+     state = {key: value.half() for key, value in state.items()}
+     os.makedirs(os.path.dirname(dst_path), exist_ok=True)
+     torch.save(state, dst_path)
+
+
+ def main():
+     with open('config.yaml', 'r') as f:
+         config = yaml.load(f)
+
+     for src_rel_path, dst_rel_path in WEIGHTS_MAPPING.items():
+         src_path = os.path.join(config['ARTIFACTS_PATH'], src_rel_path)
+         dst_path = os.path.join(config['MODELS_PATH'], dst_rel_path)
+         copy_weights(src_path, dst_path)
+
+     copy_weights(SRC_DETECTOR_WEIGHTS, os.path.join(config['MODELS_PATH'], DST_DETECTOR_WEIGHTS))
+
+
+ if __name__ == '__main__':
+     main()
datasets.py ADDED
@@ -0,0 +1,162 @@
1
+ import os
2
+ import random
3
+ import glob
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class UnlabeledVideoDataset(Dataset):
12
+ def __init__(self, root_dir, content=None, transform=None):
13
+ self.root_dir = os.path.normpath(root_dir)
14
+ self.transform = transform
15
+
16
+ if content is not None:
17
+ self.content = content
18
+ else:
19
+ self.content = []
20
+ for path in glob.iglob(os.path.join(self.root_dir, '**', '*.mp4'), recursive=True):
21
+ rel_path = path[len(self.root_dir) + 1:]
22
+ self.content.append(rel_path)
23
+ self.content = sorted(self.content)
24
+
25
+ def __len__(self):
26
+ return len(self.content)
27
+
28
+ def __getitem__(self, idx):
29
+ rel_path = self.content[idx]
30
+ path = os.path.join(self.root_dir, rel_path)
31
+
32
+ capture = cv2.VideoCapture(path)
33
+
34
+ frames = []
35
+ if capture.isOpened():
36
+ while True:
37
+ ret, frame = capture.read()
38
+ if not ret:
39
+ break
40
+
41
+ if self.transform is not None:
42
+ frame = self.transform(frame)
43
+
44
+ frames.append(frame)
45
+
46
+ sample = {
47
+ 'frames': frames,
48
+ 'index': idx
49
+ }
50
+
51
+ return sample
52
+
53
+
54
+ class FaceDataset(Dataset):
55
+ def __init__(self, root_dir, content, labels=None, transform=None):
56
+ self.root_dir = os.path.normpath(root_dir)
57
+ self.content = content
58
+ self.labels = labels
59
+ self.transform = transform
60
+
61
+ def __len__(self):
62
+ return len(self.content)
63
+
64
+ def __getitem__(self, idx):
65
+ rel_path = self.content[idx]
66
+ path = os.path.join(self.root_dir, rel_path)
67
+
68
+ face = cv2.imread(path, cv2.IMREAD_COLOR)
69
+ face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
70
+
71
+ if self.transform is not None:
72
+ face = self.transform(image=face)['image']
73
+
74
+ sample = {
75
+ 'face': face,
76
+ 'index': idx
77
+ }
78
+
79
+ if self.labels is not None:
80
+ sample['label'] = self.labels[idx]
81
+
82
+ return sample
83
+
84
+
85
+ class TrackPairDataset(Dataset):
86
+ FPS = 30
87
+
88
+ def __init__(self, tracks_root, pairs_path, indices, track_length, track_transform=None, image_transform=None,
89
+ sequence_mode=True):
90
+ self.tracks_root = os.path.normpath(tracks_root)
91
+ self.track_transform = track_transform
92
+ self.image_transform = image_transform
93
+ self.indices = np.asarray(indices, dtype=np.int32)
94
+ self.track_length = track_length
95
+ self.sequence_mode = sequence_mode
96
+
97
+ self.pairs = []
98
+ with open(pairs_path, 'r') as f:
99
+ for line in f:
100
+ real_track, fake_track = line.strip().split(',')
101
+ self.pairs.append((real_track, fake_track))
102
+
103
+ def __len__(self):
104
+ return len(self.pairs)
105
+
106
+ def __getitem__(self, idx):
107
+ real_track_path, fake_track_path = self.pairs[idx]
108
+
109
+ real_track_path = os.path.join(self.tracks_root, real_track_path)
110
+ fake_track_path = os.path.join(self.tracks_root, fake_track_path)
111
+
112
+ if self.track_transform is not None:
113
+ img = self.load_img(real_track_path, 0)
114
+ src_height, src_width = img.shape[:2]
115
+ track_transform_params = self.track_transform.get_params(self.FPS, src_height, src_width)
116
+ else:
117
+ track_transform_params = None
118
+
119
+ real_track = self.load_track(real_track_path, self.indices, track_transform_params)
120
+ fake_track = self.load_track(fake_track_path, self.indices, track_transform_params)
121
+
122
+ if self.image_transform is not None:
123
+ prev_state = random.getstate()
124
+ transformed_real_track = []
125
+ for img in real_track:
126
+ if self.sequence_mode:
127
+ random.setstate(prev_state)
128
+ transformed_real_track.append(self.image_transform(image=img)['image'])
129
+
130
+ real_track = transformed_real_track
131
+
132
+ random.setstate(prev_state)
133
+ transformed_fake_track = []
134
+ for img in fake_track:
135
+ if self.sequence_mode:
136
+ random.setstate(prev_state)
137
+ transformed_fake_track.append(self.image_transform(image=img)['image'])
138
+ fake_track = transformed_fake_track
139
+
140
+ sample = {
141
+ 'real': real_track,
142
+ 'fake': fake_track
143
+ }
144
+
145
+ return sample
146
+
147
+ def load_img(self, track_path, idx):
148
+ img = cv2.imread(os.path.join(track_path, '{}.png'.format(idx)))
149
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
150
+
151
+ return img
152
+
153
+ def load_track(self, track_path, indices, transform_params):
154
+ if transform_params is None:
155
+ track = np.stack([self.load_img(track_path, idx) for idx in indices])
156
+ else:
157
+ track = self.track_transform(track_path, self.FPS, *transform_params)
158
+ indices = (indices.astype(np.float32) / self.track_length) * len(track)
159
+ indices = np.round(indices).astype(np.int32).clip(0, len(track) - 1)
160
+ track = track[indices]
161
+
162
+ return track
detect_faces_on_videos.py ADDED
@@ -0,0 +1,82 @@
1
+ import argparse
2
+ import os
3
+ import glob
4
+ import yaml
5
+ import pickle
6
+ import tqdm
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+
11
+ from dsfacedetector.face_ssd_infer import SSD
12
+ from datasets import UnlabeledVideoDataset
13
+
14
+ DETECTOR_WEIGHTS_PATH = 'external_data/WIDERFace_DSFD_RES152.pth'
15
+ DETECTOR_THRESHOLD = 0.3
16
+ DETECTOR_STEP = 6
17
+ DETECTOR_TARGET_SIZE = (512, 512)
18
+
19
+ BATCH_SIZE = 1
20
+ NUM_WORKERS = 0
21
+
22
+ DETECTIONS_ROOT = 'detections'
23
+ DETECTIONS_FILE_NAME = 'detections.pkl'
24
+
25
+
26
+ def main():
27
+ parser = argparse.ArgumentParser(description='Detects faces on videos')
28
+ parser.add_argument('--num_parts', type=int, default=1, help='Number of parts')
29
+ parser.add_argument('--part', type=int, default=0, help='Part index')
30
+
31
+ args = parser.parse_args()
32
+
33
+ with open('config.yaml', 'r') as f:
34
+ config = yaml.load(f)
35
+
36
+ content = []
37
+ for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], 'dfdc_train_part_*', '*.mp4')):
38
+ parts = path.split('/')
39
+ content.append('/'.join(parts[-2:]))
40
+ content = sorted(content)
41
+
42
+ print('Total number of videos: {}'.format(len(content)))
43
+
44
+ part_size = len(content) // args.num_parts + 1
45
+ assert part_size * args.num_parts >= len(content)
46
+ part_start = part_size * args.part
47
+ part_end = min(part_start + part_size, len(content))
48
+ print('Part {} ({}, {})'.format(args.part, part_start, part_end))
49
+
50
+ dataset = UnlabeledVideoDataset(config['DFDC_DATA_PATH'], content[part_start:part_end])
51
+
52
+ detector = SSD('test')
53
+ state = torch.load(DETECTOR_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
54
+ detector.load_state_dict(state)
55
+ device = torch.device('cuda')
56
+ detector = detector.eval().to(device)
57
+
58
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=lambda X: X,
59
+ drop_last=False)
60
+
61
+ dst_root = os.path.join(config['ARTIFACTS_PATH'], DETECTIONS_ROOT)
62
+ os.makedirs(dst_root, exist_ok=True)
63
+
64
+ for video_sample in tqdm.tqdm(loader):
65
+ frames = video_sample[0]['frames']
66
+ video_idx = video_sample[0]['index']
67
+ video_rel_path = dataset.content[video_idx]
68
+
69
+ detections = []
70
+ for frame in frames[::DETECTOR_STEP]:
71
+ with torch.no_grad():
72
+ detections_per_frame = detector.detect_on_image(frame, DETECTOR_TARGET_SIZE, device, is_pad=False,
73
+ keep_thresh=DETECTOR_THRESHOLD)
74
+ detections.append({'boxes': detections_per_frame[:, :4], 'scores': detections_per_frame[:, 4]})
75
+
76
+ os.makedirs(os.path.join(dst_root, video_rel_path), exist_ok=True)
77
+ with open(os.path.join(dst_root, video_rel_path, DETECTIONS_FILE_NAME), 'wb') as f:
78
+ pickle.dump(detections, f)
79
+
80
+
81
+ if __name__ == '__main__':
82
+ main()
dsfacedetector/__init__.py ADDED
File without changes
dsfacedetector/data/__init__.py ADDED
File without changes
dsfacedetector/data/config.py ADDED
@@ -0,0 +1,57 @@
1
+ import numpy as np
2
+
3
+
4
+ def test_base_transform(image, mean):
5
+ x = image.astype(np.float32)
6
+ x -= mean
7
+ x = x.astype(np.float32)
8
+ return x
9
+
10
+
11
+ class TestBaseTransform:
12
+ def __init__(self, mean):
13
+ self.mean = np.array(mean, dtype=np.float32)
14
+
15
+ def __call__(self, image):
16
+ return test_base_transform(image, self.mean)
17
+
18
+
19
+ widerface_640 = {
20
+ 'num_classes': 2,
21
+
22
+ 'feature_maps': [160, 80, 40, 20, 10, 5],
23
+ 'min_dim': 640,
24
+
25
+ 'steps': [4, 8, 16, 32, 64, 128], # stride
26
+
27
+ 'variance': [0.1, 0.2],
28
+ 'clip': True, # make default box in [0,1]
29
+ 'name': 'WIDERFace',
30
+ 'l2norm_scale': [10, 8, 5],
31
+ 'base': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512],
32
+ 'extras': [256, 'S', 512, 128, 'S', 256],
33
+
34
+ 'mbox': [1, 1, 1, 1, 1, 1],
35
+ 'min_sizes': [16, 32, 64, 128, 256, 512],
36
+ 'max_sizes': [],
37
+ 'aspect_ratios': [[1.5], [1.5], [1.5], [1.5], [1.5], [1.5]], # [1,2] default 1
38
+
39
+ 'backbone': 'resnet152',
40
+ 'feature_pyramid_network': True,
41
+ 'bottom_up_path': False,
42
+ 'feature_enhance_module': True,
43
+ 'max_in_out': True,
44
+ 'focal_loss': False,
45
+ 'progressive_anchor': True,
46
+ 'refinedet': False,
47
+ 'max_out': False,
48
+ 'anchor_compensation': False,
49
+ 'data_anchor_sampling': False,
50
+
51
+ 'overlap_thresh': [0.4],
52
+ 'negpos_ratio': 3,
53
+ # test
54
+ 'nms_thresh': 0.3,
55
+ 'conf_thresh': 0.01,
56
+ 'num_thresh': 5000,
57
+ }
dsfacedetector/face_ssd_infer.py ADDED
@@ -0,0 +1,156 @@
1
+ # Source: https://github.com/vlad3996/FaceDetection-DSFD
2
+
3
+ import torch
4
+ import torchvision
5
+ import torch.nn as nn
6
+
7
+ from .data.config import TestBaseTransform, widerface_640 as cfg
8
+ from .layers import Detect, get_prior_boxes, FEM, pa_multibox, mio_module, upsample_product
9
+ from .utils import resize_image
10
+
11
+
12
+ class SSD(nn.Module):
13
+
14
+ def __init__(self, phase, nms_thresh=0.3, nms_conf_thresh=0.01):
15
+ super(SSD, self).__init__()
16
+ self.phase = phase
17
+ self.num_classes = 2
18
+ self.cfg = cfg
19
+
20
+ resnet = torchvision.models.resnet152(pretrained=False)
21
+
22
+ self.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1)
23
+ self.layer2 = nn.Sequential(resnet.layer2)
24
+ self.layer3 = nn.Sequential(resnet.layer3)
25
+ self.layer4 = nn.Sequential(resnet.layer4)
26
+ self.layer5 = nn.Sequential(
27
+ *[nn.Conv2d(2048, 512, kernel_size=1),
28
+ nn.BatchNorm2d(512),
29
+ nn.ReLU(inplace=True),
30
+ nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=2),
31
+ nn.BatchNorm2d(512),
32
+ nn.ReLU(inplace=True)]
33
+ )
34
+ self.layer6 = nn.Sequential(
35
+ *[nn.Conv2d(512, 128, kernel_size=1, ),
36
+ nn.BatchNorm2d(128),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
39
+ nn.BatchNorm2d(256),
40
+ nn.ReLU(inplace=True)]
41
+ )
42
+
43
+ output_channels = [256, 512, 1024, 2048, 512, 256]
44
+
45
+ # FPN
46
+ fpn_in = output_channels
47
+
48
+ self.latlayer3 = nn.Conv2d(fpn_in[3], fpn_in[2], kernel_size=1, stride=1, padding=0)
49
+ self.latlayer2 = nn.Conv2d(fpn_in[2], fpn_in[1], kernel_size=1, stride=1, padding=0)
50
+ self.latlayer1 = nn.Conv2d(fpn_in[1], fpn_in[0], kernel_size=1, stride=1, padding=0)
51
+
52
+ self.smooth3 = nn.Conv2d(fpn_in[2], fpn_in[2], kernel_size=1, stride=1, padding=0)
53
+ self.smooth2 = nn.Conv2d(fpn_in[1], fpn_in[1], kernel_size=1, stride=1, padding=0)
54
+ self.smooth1 = nn.Conv2d(fpn_in[0], fpn_in[0], kernel_size=1, stride=1, padding=0)
55
+
56
+ # FEM
57
+ cpm_in = output_channels
58
+
59
+ self.cpm3_3 = FEM(cpm_in[0])
60
+ self.cpm4_3 = FEM(cpm_in[1])
61
+ self.cpm5_3 = FEM(cpm_in[2])
62
+ self.cpm7 = FEM(cpm_in[3])
63
+ self.cpm6_2 = FEM(cpm_in[4])
64
+ self.cpm7_2 = FEM(cpm_in[5])
65
+
66
+ # head
67
+ head = pa_multibox(output_channels)
68
+ self.loc = nn.ModuleList(head[0])
69
+ self.conf = nn.ModuleList(head[1])
70
+
71
+ self.softmax = nn.Softmax(dim=-1)
72
+
73
+ if self.phase != 'onnx_export':
74
+ self.detect = Detect(self.num_classes, 0, cfg['num_thresh'], nms_conf_thresh, nms_thresh,
75
+ cfg['variance'])
76
+ self.last_image_size = None
77
+ self.last_feature_maps = None
78
+
79
+ if self.phase == 'test':
80
+ self.test_transform = TestBaseTransform((104, 117, 123))
81
+
82
+ def forward(self, x):
83
+
84
+ image_size = [x.shape[2], x.shape[3]]
85
+ loc = list()
86
+ conf = list()
87
+
88
+ conv3_3_x = self.layer1(x)
89
+ conv4_3_x = self.layer2(conv3_3_x)
90
+ conv5_3_x = self.layer3(conv4_3_x)
91
+ fc7_x = self.layer4(conv5_3_x)
92
+ conv6_2_x = self.layer5(fc7_x)
93
+ conv7_2_x = self.layer6(conv6_2_x)
94
+
95
+ lfpn3 = upsample_product(self.latlayer3(fc7_x), self.smooth3(conv5_3_x))
96
+ lfpn2 = upsample_product(self.latlayer2(lfpn3), self.smooth2(conv4_3_x))
97
+ lfpn1 = upsample_product(self.latlayer1(lfpn2), self.smooth1(conv3_3_x))
98
+
99
+ conv5_3_x = lfpn3
100
+ conv4_3_x = lfpn2
101
+ conv3_3_x = lfpn1
102
+
103
+ sources = [conv3_3_x, conv4_3_x, conv5_3_x, fc7_x, conv6_2_x, conv7_2_x]
104
+
105
+ sources[0] = self.cpm3_3(sources[0])
106
+ sources[1] = self.cpm4_3(sources[1])
107
+ sources[2] = self.cpm5_3(sources[2])
108
+ sources[3] = self.cpm7(sources[3])
109
+ sources[4] = self.cpm6_2(sources[4])
110
+ sources[5] = self.cpm7_2(sources[5])
111
+
112
+ # apply multibox head to source layers
113
+ featuremap_size = []
114
+ for (x, l, c) in zip(sources, self.loc, self.conf):
115
+ featuremap_size.append([x.shape[2], x.shape[3]])
116
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
117
+ len_conf = len(conf)
118
+ cls = mio_module(c(x), len_conf)
119
+ conf.append(cls.permute(0, 2, 3, 1).contiguous())
120
+
121
+ face_loc = torch.cat([o[:, :, :, :4].contiguous().view(o.size(0), -1) for o in loc], 1)
122
+ face_loc = face_loc.view(face_loc.size(0), -1, 4)
123
+ face_conf = torch.cat([o[:, :, :, :2].contiguous().view(o.size(0), -1) for o in conf], 1)
124
+ face_conf = self.softmax(face_conf.view(face_conf.size(0), -1, self.num_classes))
125
+
126
+ if self.phase != 'onnx_export':
127
+
128
+ if self.last_image_size is None or self.last_image_size != image_size or self.last_feature_maps != featuremap_size:
129
+ self.priors = get_prior_boxes(self.cfg, featuremap_size, image_size).to(face_loc.device)
130
+ self.last_image_size = image_size
131
+ self.last_feature_maps = featuremap_size
132
+ with torch.no_grad():
133
+ output = self.detect(face_loc, face_conf, self.priors)
134
+ else:
135
+ output = torch.cat((face_loc, face_conf), 2)
136
+ return output
137
+
138
+ def detect_on_image(self, source_image, target_size, device, is_pad=False, keep_thresh=0.3):
139
+
140
+ image, shift_h_scaled, shift_w_scaled, scale = resize_image(source_image, target_size, is_pad=is_pad)
141
+
142
+ x = torch.from_numpy(self.test_transform(image)).permute(2, 0, 1).to(device)
143
+ x.unsqueeze_(0)
144
+
145
+ detections = self.forward(x).cpu().numpy()
146
+
147
+ scores = detections[0, 1, :, 0]
148
+ keep_idxs = scores > keep_thresh # find keeping indexes
149
+ detections = detections[0, 1, keep_idxs, :] # select detections over threshold
150
+ detections = detections[:, [1, 2, 3, 4, 0]] # reorder
151
+
152
+ detections[:, [0, 2]] -= shift_w_scaled # 0 or pad percent from left corner
153
+ detections[:, [1, 3]] -= shift_h_scaled # 0 or pad percent from top
154
+ detections[:, :4] *= scale
155
+
156
+ return detections
dsfacedetector/layers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .detection import Detect
+ from .prior_box import PriorBox, get_prior_boxes
+ from .modules import FEM, pa_multibox, mio_module, upsample_product
dsfacedetector/layers/detection.py ADDED
@@ -0,0 +1,157 @@
1
+ from __future__ import division
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Detect(nn.Module):
7
+ """At test time, Detect is the final layer of SSD. Decode location preds,
8
+ apply non-maximum suppression to location predictions based on conf
9
+ scores and threshold to a top_k number of output predictions for both
10
+ confidence score and locations.
11
+ """
12
+
13
+ def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh, variance=(0.1, 0.2)):
14
+ super(Detect, self).__init__()
15
+ self.num_classes = num_classes
16
+ self.background_label = bkg_label
17
+ self.top_k = top_k
18
+ # Parameters used in nms.
19
+ self.nms_thresh = nms_thresh
20
+ if nms_thresh <= 0:
21
+ raise ValueError('nms_threshold must be non negative.')
22
+ self.conf_thresh = conf_thresh
23
+ self.variance = variance
24
+
25
+ def forward(self, loc_data, conf_data, prior_data):
26
+ """
27
+ Args:
28
+ loc_data: (tensor) Loc preds from loc layers
29
+ Shape: [batch,num_priors*4]
30
+ conf_data: (tensor) Shape: Conf preds from conf layers
31
+ Shape: [batch*num_priors,num_classes]
32
+ prior_data: (tensor) Prior boxes and variances from priorbox layers
33
+ Shape: [1,num_priors,4]
34
+ """
35
+ num = loc_data.size(0) # batch size
36
+ num_priors = prior_data.size(0)
37
+
38
+ output = torch.zeros(num, self.num_classes, self.top_k, 5)
39
+ conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)
40
+
41
+ # Decode predictions into bboxes.
42
+ for i in range(num):
43
+ default = prior_data
44
+ decoded_boxes = decode(loc_data[i], default, self.variance)
45
+ # For each class, perform nms
46
+ conf_scores = conf_preds[i].clone()
47
+
48
+ for cl in range(1, self.num_classes):
49
+ c_mask = conf_scores[cl].gt(self.conf_thresh)
50
+ scores = conf_scores[cl][c_mask]
51
+ if scores.dim() == 0 or scores.size(0) == 0:
52
+ continue
53
+ l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
54
+ boxes = decoded_boxes[l_mask].view(-1, 4)
55
+ # idx of highest scoring and non-overlapping boxes per class
56
+ ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
57
+ output[i, cl, :count] = \
58
+ torch.cat((scores[ids[:count]].unsqueeze(1),
59
+ boxes[ids[:count]]), 1)
60
+ flt = output.contiguous().view(num, -1, 5)
61
+ _, idx = flt[:, :, 0].sort(1, descending=True)
62
+ _, rank = idx.sort(1)
63
+ flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
64
+ return output
65
+
66
+
67
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
68
+ def decode(loc, priors, variances):
69
+ """Decode locations from predictions using priors to undo
70
+ the encoding we did for offset regression at train time.
71
+ Args:
72
+ loc (tensor): location predictions for loc layers,
73
+ Shape: [num_priors,4]
74
+ priors (tensor): Prior boxes in center-offset form.
75
+ Shape: [num_priors,4].
76
+ variances: (list[float]) Variances of priorboxes
77
+ Return:
78
+ decoded bounding box predictions
79
+ """
80
+
81
+ boxes = torch.cat((
82
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
83
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
84
+ boxes[:, :2] -= boxes[:, 2:] / 2
85
+ boxes[:, 2:] += boxes[:, :2]
86
+ # (cx,cy,w,h)->(x0,y0,x1,y1)
87
+ return boxes
88
+
89
+
90
+ # Original author: Francisco Massa:
91
+ # https://github.com/fmassa/object-detection.torch
92
+ # Ported to PyTorch by Max deGroot (02/01/2017)
93
+ def nms(boxes, scores, overlap=0.5, top_k=200):
94
+ """Apply non-maximum suppression at test time to avoid detecting too many
95
+ overlapping bounding boxes for a given object.
96
+ Args:
97
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
98
+ scores: (tensor) The class predscores for the img, Shape:[num_priors].
99
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
100
+ top_k: (int) The Maximum number of box preds to consider.
101
+ Return:
102
+ The indices of the kept boxes with respect to num_priors.
103
+ """
104
+
105
+ keep = scores.new(scores.size(0)).zero_().long()
106
+ if boxes.numel() == 0:
107
+ return keep
108
+ x1 = boxes[:, 0]
109
+ y1 = boxes[:, 1]
110
+ x2 = boxes[:, 2]
111
+ y2 = boxes[:, 3]
112
+ area = torch.mul(x2 - x1, y2 - y1)
113
+ v, idx = scores.sort(0) # sort in ascending order
114
+ # I = I[v >= 0.01]
115
+ idx = idx[-top_k:] # indices of the top-k largest vals
116
+ xx1 = boxes.new()
117
+ yy1 = boxes.new()
118
+ xx2 = boxes.new()
119
+ yy2 = boxes.new()
120
+ w = boxes.new()
121
+ h = boxes.new()
122
+
123
+ # keep = torch.Tensor()
124
+ count = 0
125
+ while idx.numel() > 0:
126
+ i = idx[-1] # index of current largest val
127
+ # keep.append(i)
128
+ keep[count] = i
129
+ count += 1
130
+ if idx.size(0) == 1:
131
+ break
132
+ idx = idx[:-1] # remove kept element from view
133
+ # load bboxes of next highest vals
134
+ torch.index_select(x1, 0, idx, out=xx1)
135
+ torch.index_select(y1, 0, idx, out=yy1)
136
+ torch.index_select(x2, 0, idx, out=xx2)
137
+ torch.index_select(y2, 0, idx, out=yy2)
138
+ # store element-wise max with next highest score
139
+ xx1 = torch.clamp(xx1, min=x1[i])
140
+ yy1 = torch.clamp(yy1, min=y1[i])
141
+ xx2 = torch.clamp(xx2, max=x2[i])
142
+ yy2 = torch.clamp(yy2, max=y2[i])
143
+ w.resize_as_(xx2)
144
+ h.resize_as_(yy2)
145
+ w = xx2 - xx1
146
+ h = yy2 - yy1
147
+ # check sizes of xx1 and xx2.. after each iteration
148
+ w = torch.clamp(w, min=0.0)
149
+ h = torch.clamp(h, min=0.0)
150
+ inter = w * h
151
+ # IoU = i / (area(a) + area(b) - i)
152
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
153
+ union = (rem_areas - inter) + area[i]
154
+ IoU = inter / union # store result in iou
155
+ # keep only elements with an IoU <= overlap
156
+ idx = idx[IoU.le(overlap)]
157
+ return keep, count
dsfacedetector/layers/modules.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DeepHeadModule(nn.Module):
7
+ def __init__(self, input_channels, output_channels):
8
+ super(DeepHeadModule, self).__init__()
9
+ self._input_channels = input_channels
10
+ self._output_channels = output_channels
11
+ self._mid_channels = min(self._input_channels, 256)
12
+
13
+ self.conv1 = nn.Conv2d(self._input_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
14
+ self.conv2 = nn.Conv2d(self._mid_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
15
+ self.conv3 = nn.Conv2d(self._mid_channels, self._mid_channels, kernel_size=3, dilation=1, stride=1, padding=1)
16
+ self.conv4 = nn.Conv2d(self._mid_channels, self._output_channels, kernel_size=1, dilation=1, stride=1,
17
+ padding=0)
18
+
19
+ def forward(self, x):
20
+ return self.conv4(
21
+ F.relu(self.conv3(F.relu(self.conv2(F.relu(self.conv1(x), inplace=True)), inplace=True)), inplace=True))
22
+
23
+
24
+ class FEM(nn.Module):
25
+ def __init__(self, channel_size):
26
+ super(FEM, self).__init__()
27
+ self.cs = channel_size
28
+ self.cpm1 = nn.Conv2d(self.cs, 256, kernel_size=3, dilation=1, stride=1, padding=1)
29
+ self.cpm2 = nn.Conv2d(self.cs, 256, kernel_size=3, dilation=2, stride=1, padding=2)
30
+ self.cpm3 = nn.Conv2d(256, 128, kernel_size=3, dilation=1, stride=1, padding=1)
31
+ self.cpm4 = nn.Conv2d(256, 128, kernel_size=3, dilation=2, stride=1, padding=2)
32
+ self.cpm5 = nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1)
33
+
34
+ def forward(self, x):
35
+ x1_1 = F.relu(self.cpm1(x), inplace=True)
36
+ x1_2 = F.relu(self.cpm2(x), inplace=True)
37
+ x2_1 = F.relu(self.cpm3(x1_2), inplace=True)
38
+ x2_2 = F.relu(self.cpm4(x1_2), inplace=True)
39
+ x3_1 = F.relu(self.cpm5(x2_2), inplace=True)
40
+ return torch.cat((x1_1, x2_1, x3_1), 1)
41
+
42
+
43
+ def upsample_product(x, y):
44
+ '''Upsample and add two feature maps.
45
+ Args:
46
+ x: (Variable) top feature map to be upsampled.
47
+ y: (Variable) lateral feature map.
48
+ Returns:
49
+ (Variable) added feature map.
50
+ Note in PyTorch, when input size is odd, the upsampled feature map
51
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
52
+ maybe not equal to the lateral feature map size.
53
+ e.g.
54
+ original input size: [N,_,15,15] ->
55
+ conv2d feature map size: [N,_,8,8] ->
56
+ upsampled feature map size: [N,_,16,16]
57
+ So we choose bilinear upsample which supports arbitrary output sizes.
58
+ '''
59
+ _, _, H, W = y.size()
60
+
61
+ # FOR ONNX CONVERSION
62
+ # return F.interpolate(x, scale_factor=2, mode='nearest') * y
63
+ return F.interpolate(x, size=(int(H), int(W)), mode='bilinear', align_corners=False) * y
64
+
65
+
66
+ def pa_multibox(output_channels):
67
+ loc_layers = []
68
+ conf_layers = []
69
+ for k, v in enumerate(output_channels):
70
+ if k == 0:
71
+ loc_output = 4
72
+ conf_output = 2
73
+ elif k == 1:
74
+ loc_output = 8
75
+ conf_output = 4
76
+ else:
77
+ loc_output = 12
78
+ conf_output = 6
79
+ loc_layers += [DeepHeadModule(512, loc_output)]
80
+ conf_layers += [DeepHeadModule(512, (2 + conf_output))]
81
+ return (loc_layers, conf_layers)
82
+
83
+
84
+ def mio_module(each_mmbox, len_conf, your_mind_state='peasant'):
85
+ # chunk = torch.split(each_mmbox, 1, 1) - !!!!! failed to export on PyTorch v1.0.1 (ONNX version 1.3)
86
+ chunk = torch.chunk(each_mmbox, int(each_mmbox.shape[1]), 1)
87
+
88
+ # some hacks for ONNX and Inference Engine export
89
+ if your_mind_state == 'peasant':
90
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
91
+ elif your_mind_state == 'advanced':
92
+ bmax = torch.max(each_mmbox[:, :3], 1)[0].unsqueeze(0)
93
+ else: # supermind
94
+ bmax = torch.nn.functional.max_pool3d(each_mmbox[:, :3], kernel_size=(3, 1, 1))
95
+
96
+ cls = (torch.cat((bmax, chunk[3]), dim=1) if len_conf == 0 else torch.cat((chunk[3], bmax), dim=1))
97
+ cls = torch.cat((cls, *list(chunk[4:])), dim=1)
98
+ return cls
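A small usage sketch for `upsample_product` (the shapes are hypothetical): the top-down map is bilinearly resized to the lateral map's exact height and width before the element-wise product, which is why odd spatial sizes are handled correctly:

import torch
top = torch.randn(1, 128, 8, 8)        # coarse top-down feature map
lateral = torch.randn(1, 128, 15, 15)  # lateral feature map with odd size
fused = upsample_product(top, lateral)
print(fused.shape)  # torch.Size([1, 128, 15, 15]), matches the lateral map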
dsfacedetector/layers/prior_box.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import division
2
+ from math import sqrt as sqrt
3
+ import torch
4
+
5
+
6
+ class PriorBox(object):
7
+ """Compute priorbox coordinates in center-offset form for each source
8
+ feature map.
9
+ """
10
+
11
+ def __init__(self, cfg, min_size, max_size):
12
+ super(PriorBox, self).__init__()
13
+ self.image_size = cfg['min_dim']
14
+ self.feature_maps = cfg['feature_maps']
15
+
16
+ self.variance = cfg['variance'] or [0.1]
17
+ self.min_sizes = min_size
18
+ self.max_sizes = max_size
19
+ self.steps = cfg['steps']
20
+ self.aspect_ratios = cfg['aspect_ratios']
21
+ self.clip = cfg['clip']
22
+
23
+ for v in self.variance:
24
+ if v <= 0:
25
+ raise ValueError('Variances must be greater than 0')
26
+
27
+ def forward(self):
28
+
29
+ mean = []
30
+
31
+ if len(self.min_sizes) == 5:
32
+ self.feature_maps = self.feature_maps[1:]
33
+ self.steps = self.steps[1:]
34
+ if len(self.min_sizes) == 4:
35
+ self.feature_maps = self.feature_maps[2:]
36
+ self.steps = self.steps[2:]
37
+
38
+ for k, f in enumerate(self.feature_maps):
39
+ # for i, j in product(range(f), repeat=2):
40
+ for i in range(f[0]):
41
+ for j in range(f[1]):
42
+ # f_k = self.image_size / self.steps[k]
43
+ f_k_i = self.image_size[0] / self.steps[k]
44
+ f_k_j = self.image_size[1] / self.steps[k]
45
+ # unit center x,y
46
+ cx = (j + 0.5) / f_k_j
47
+ cy = (i + 0.5) / f_k_i
48
+ # aspect_ratio: 1
49
+ # rel size: min_size
50
+ s_k_i = self.min_sizes[k] / self.image_size[1]
51
+ s_k_j = self.min_sizes[k] / self.image_size[0]
52
+ # swordli@tencent
53
+ if len(self.aspect_ratios[0]) == 0:
54
+ mean += [cx, cy, s_k_i, s_k_j]
55
+
56
+ # aspect_ratio: 1
57
+ # rel size: sqrt(s_k * s_(k+1))
58
+ # s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
59
+ if len(self.max_sizes) == len(self.min_sizes):
60
+ s_k_prime_i = sqrt(s_k_i * (self.max_sizes[k] / self.image_size[1]))
61
+ s_k_prime_j = sqrt(s_k_j * (self.max_sizes[k] / self.image_size[0]))
62
+ mean += [cx, cy, s_k_prime_i, s_k_prime_j]
63
+ # rest of aspect ratios
64
+ for ar in self.aspect_ratios[k]:
65
+ if len(self.max_sizes) == len(self.min_sizes):
66
+ mean += [cx, cy, s_k_prime_i / sqrt(ar), s_k_prime_j * sqrt(ar)]
67
+ mean += [cx, cy, s_k_i / sqrt(ar), s_k_j * sqrt(ar)]
68
+
69
+ # back to torch land
70
+ output = torch.Tensor(mean).view(-1, 4)
71
+ if self.clip:
72
+ output.clamp_(max=1, min=0)
73
+ return output
74
+
75
+
76
+ def get_prior_boxes(cfg, feature_maps, image_size):
77
+
78
+ # number of priors for feature map location (either 4 or 6)
79
+ variance = cfg['variance'] or [0.1]
80
+ min_sizes = cfg['min_sizes']
81
+ max_sizes = cfg['max_sizes']
82
+ steps = cfg['steps']
83
+ aspect_ratios = cfg['aspect_ratios']
84
+ clip = cfg['clip']
85
+ for v in variance:
86
+ if v <= 0:
87
+ raise ValueError('Variances must be greater than 0')
88
+
89
+ mean = []
90
+
91
+ if len(min_sizes) == 5:
92
+ feature_maps = feature_maps[1:]
93
+ steps = steps[1:]
94
+ if len(min_sizes) == 4:
95
+ feature_maps = feature_maps[2:]
96
+ steps = steps[2:]
97
+
98
+ for k, f in enumerate(feature_maps):
99
+ # for i, j in product(range(f), repeat=2):
100
+ for i in range(f[0]):
101
+ for j in range(f[1]):
102
+ # f_k = image_size / steps[k]
103
+ f_k_i = image_size[0] / steps[k]
104
+ f_k_j = image_size[1] / steps[k]
105
+ # unit center x,y
106
+ cx = (j + 0.5) / f_k_j
107
+ cy = (i + 0.5) / f_k_i
108
+ # aspect_ratio: 1
109
+ # rel size: min_size
110
+ s_k_i = min_sizes[k] / image_size[1]
111
+ s_k_j = min_sizes[k] / image_size[0]
112
+ # swordli@tencent
113
+ if len(aspect_ratios[0]) == 0:
114
+ mean += [cx, cy, s_k_i, s_k_j]
115
+
116
+ # aspect_ratio: 1
117
+ # rel size: sqrt(s_k * s_(k+1))
118
+ # s_k_prime = sqrt(s_k * (max_sizes[k]/image_size))
119
+ if len(max_sizes) == len(min_sizes):
120
+ s_k_prime_i = sqrt(s_k_i * (max_sizes[k] / image_size[1]))
121
+ s_k_prime_j = sqrt(s_k_j * (max_sizes[k] / image_size[0]))
122
+ mean += [cx, cy, s_k_prime_i, s_k_prime_j]
123
+ # rest of aspect ratios
124
+ for ar in aspect_ratios[k]:
125
+ if len(max_sizes) == len(min_sizes):
126
+ mean += [cx, cy, s_k_prime_i / sqrt(ar), s_k_prime_j * sqrt(ar)]
127
+ mean += [cx, cy, s_k_i / sqrt(ar), s_k_j * sqrt(ar)]
128
+
129
+ # back to torch land
130
+ output = torch.Tensor(mean).view(-1, 4)
131
+ if clip:
132
+ output.clamp_(max=1, min=0)
133
+ return output
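A hedged example of calling `get_prior_boxes`; the config values, feature-map sizes, and image size below are invented purely to show the call and the output shape (the detector's real settings live in its config module):

cfg = {
    'variance': [0.1, 0.2],
    'min_sizes': [16, 32],
    'max_sizes': [32, 64],
    'steps': [4, 8],
    'aspect_ratios': [[], []],
    'clip': True,
}
feature_maps = [[16, 16], [8, 8]]   # (h, w) of each source feature map
image_size = (64, 64)               # (H, W) of the network input
priors = get_prior_boxes(cfg, feature_maps, image_size)
# two priors per location (min_size box + sqrt(min*max) box):
# (16*16 + 8*8) * 2 = 640 rows of [cx, cy, w, h] in relative coordinates
print(priors.shape)  # torch.Size([640, 4])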
dsfacedetector/utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
+ def vis_detections(im, dets, thresh=0.5, show_text=True):
7
+ """Draw detected bounding boxes."""
8
+ class_name = 'face'
9
+ inds = np.where(dets[:, -1] >= thresh)[0] if dets is not None else []
10
+ if len(inds) == 0:
11
+ return
12
+ im = im[:, :, (2, 1, 0)]
13
+ fig, ax = plt.subplots(figsize=(12, 12))
14
+ ax.imshow(im, aspect='equal')
15
+ for i in inds:
16
+ bbox = dets[i, :4]
17
+ score = dets[i, -1]
18
+ ax.add_patch(
19
+ plt.Rectangle((bbox[0], bbox[1]),
20
+ bbox[2] - bbox[0],
21
+ bbox[3] - bbox[1], fill=False,
22
+ edgecolor='red', linewidth=2.5)
23
+ )
24
+ if show_text:
25
+ ax.text(bbox[0], bbox[1] - 5,
26
+ '{:s} {:.3f}'.format(class_name, score),
27
+ bbox=dict(facecolor='blue', alpha=0.5),
28
+ fontsize=10, color='white')
29
+ ax.set_title(('{} detections with '
30
+ 'p({} | box) >= {:.1f}').format(class_name, class_name,
31
+ thresh),
32
+ fontsize=10)
33
+ plt.axis('off')
34
+ plt.tight_layout()
35
+ plt.savefig('out.png')
36
+ plt.show()
37
+
38
+
39
+ def bbox_vote(det):
40
+ order = det[:, 4].ravel().argsort()[::-1]
41
+ det = det[order, :]
42
+ dets = None
43
+ while det.shape[0] > 0:
44
+ # IOU
45
+ area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
46
+ xx1 = np.maximum(det[0, 0], det[:, 0])
47
+ yy1 = np.maximum(det[0, 1], det[:, 1])
48
+ xx2 = np.minimum(det[0, 2], det[:, 2])
49
+ yy2 = np.minimum(det[0, 3], det[:, 3])
50
+ w = np.maximum(0.0, xx2 - xx1 + 1)
51
+ h = np.maximum(0.0, yy2 - yy1 + 1)
52
+ inter = w * h
53
+ o = inter / (area[0] + area[:] - inter)
54
+ # select detections to merge (IoU >= 0.3) and remove them from the remaining set
55
+ merge_index = np.where(o >= 0.3)[0]
56
+ det_accu = det[merge_index, :]
57
+ det = np.delete(det, merge_index, 0)
58
+ if merge_index.shape[0] <= 1:
59
+ continue
60
+ det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
61
+ max_score = np.max(det_accu[:, 4])
62
+ det_accu_sum = np.zeros((1, 5))
63
+ det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4], axis=0) / np.sum(det_accu[:, -1:])
64
+ det_accu_sum[:, 4] = max_score
65
+ try:
66
+ dets = np.row_stack((dets, det_accu_sum))
67
+ except:
68
+ dets = det_accu_sum
69
+ if dets is not None:
70
+ dets = dets[0:750, :]
71
+ return dets
72
+
73
+
74
+ def add_borders(curr_img, target_shape=(224, 224), fill_type=0):
75
+ curr_h, curr_w = curr_img.shape[0:2]
76
+ shift_h = max(target_shape[0] - curr_h, 0)
77
+ shift_w = max(target_shape[1] - curr_w, 0)
78
+
79
+ image = cv2.copyMakeBorder(curr_img, shift_h // 2, (shift_h + 1) // 2, shift_w // 2, (shift_w + 1) // 2, fill_type)
80
+ return image, shift_h, shift_w
81
+
82
+
83
+ def resize_image(image, target_size, resize_factor=None, is_pad=True, interpolation=3):
84
+ curr_image_size = image.shape[0:2]
85
+
86
+ if resize_factor is None and is_pad:
87
+ resize_factor = min(target_size[0] / curr_image_size[0], target_size[1] / curr_image_size[1])
88
+ elif resize_factor is None and not is_pad:
89
+ resize_factor = np.sqrt((target_size[0] * target_size[1]) / (curr_image_size[0] * curr_image_size[1]))
90
+
91
+ image = cv2.resize(image, None, None, fx=resize_factor, fy=resize_factor, interpolation=interpolation)
92
+
93
+ if is_pad:
94
+ image, shift_h, shift_w = add_borders(image, target_size)
95
+ else:
96
+ shift_h = shift_w = 0
97
+
98
+ scale = np.array([image.shape[1]/resize_factor, image.shape[0]/resize_factor,
99
+ image.shape[1]/resize_factor, image.shape[0]/resize_factor])
100
+
101
+ return image, shift_h/image.shape[0]/2, shift_w/image.shape[1]/2, scale
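Assuming the helpers above are in scope, a quick sketch of `resize_image` on a hypothetical 720p frame: the image is scaled by the limiting factor and then padded with borders up to the target size:

import numpy as np
frame = np.zeros((720, 1280, 3), dtype=np.uint8)            # e.g. one video frame
resized, shift_h, shift_w, scale = resize_image(frame, (640, 640))
print(resized.shape)   # (640, 640, 3): scaled by 0.5, then padded vertically
print(scale)           # padded size divided by the resize factor, per box coordinate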
external_data/convert_tf_to_pt.py ADDED
@@ -0,0 +1,174 @@
1
+ # Source: https://github.com/lukemelas/EfficientNet-PyTorch
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import torch
6
+
7
+ def load_param(checkpoint_file, conversion_table, model_name):
8
+ """
9
+ Load parameters according to conversion_table.
10
+
11
+ Args:
12
+ checkpoint_file (string): pretrained checkpoint model file in tensorflow
13
+ conversion_table (dict): { pytorch tensor in a model : checkpoint variable name }
14
+ """
15
+ for pyt_param, tf_param_name in conversion_table.items():
16
+ tf_param_name = str(model_name) + '/' + tf_param_name
17
+ tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)
18
+ if 'conv' in tf_param_name and 'kernel' in tf_param_name:
19
+ tf_param = np.transpose(tf_param, (3, 2, 0, 1))
20
+ if 'depthwise' in tf_param_name:
21
+ tf_param = np.transpose(tf_param, (1, 0, 2, 3))
22
+ elif tf_param_name.endswith('kernel'): # for weight(kernel), we should do transpose
23
+ tf_param = np.transpose(tf_param)
24
+ assert pyt_param.size() == tf_param.shape, \
25
+ 'Dim Mismatch: %s vs %s ; %s' % (tuple(pyt_param.size()), tf_param.shape, tf_param_name)
26
+ pyt_param.data = torch.from_numpy(tf_param)
27
+
28
+
29
+ def load_efficientnet(model, checkpoint_file, model_name):
30
+ """
31
+ Load PyTorch EfficientNet from TensorFlow checkpoint file
32
+ """
33
+
34
+ # This will store the entire conversion table
35
+ conversion_table = {}
36
+ merge = lambda dict1, dict2: {**dict1, **dict2}
37
+
38
+ # All the weights not in the conv blocks
39
+ conversion_table_for_weights_outside_blocks = {
40
+ model._conv_stem.weight: 'stem/conv2d/kernel', # [3, 3, 3, 32]),
41
+ model._bn0.bias: 'stem/tpu_batch_normalization/beta', # [32]),
42
+ model._bn0.weight: 'stem/tpu_batch_normalization/gamma', # [32]),
43
+ model._bn0.running_mean: 'stem/tpu_batch_normalization/moving_mean', # [32]),
44
+ model._bn0.running_var: 'stem/tpu_batch_normalization/moving_variance', # [32]),
45
+ model._conv_head.weight: 'head/conv2d/kernel', # [1, 1, 320, 1280]),
46
+ model._bn1.bias: 'head/tpu_batch_normalization/beta', # [1280]),
47
+ model._bn1.weight: 'head/tpu_batch_normalization/gamma', # [1280]),
48
+ model._bn1.running_mean: 'head/tpu_batch_normalization/moving_mean', # [32]),
49
+ model._bn1.running_var: 'head/tpu_batch_normalization/moving_variance', # [32]),
50
+ model._fc.bias: 'head/dense/bias', # [1000]),
51
+ model._fc.weight: 'head/dense/kernel', # [1280, 1000]),
52
+ }
53
+ conversion_table = merge(conversion_table, conversion_table_for_weights_outside_blocks)
54
+
55
+ # The first conv block is special because it does not have _expand_conv
56
+ conversion_table_for_first_block = {
57
+ model._blocks[0]._project_conv.weight: 'blocks_0/conv2d/kernel', # 1, 1, 32, 16]),
58
+ model._blocks[0]._depthwise_conv.weight: 'blocks_0/depthwise_conv2d/depthwise_kernel', # [3, 3, 32, 1]),
59
+ model._blocks[0]._se_reduce.bias: 'blocks_0/se/conv2d/bias', # , [8]),
60
+ model._blocks[0]._se_reduce.weight: 'blocks_0/se/conv2d/kernel', # , [1, 1, 32, 8]),
61
+ model._blocks[0]._se_expand.bias: 'blocks_0/se/conv2d_1/bias', # , [32]),
62
+ model._blocks[0]._se_expand.weight: 'blocks_0/se/conv2d_1/kernel', # , [1, 1, 8, 32]),
63
+ model._blocks[0]._bn1.bias: 'blocks_0/tpu_batch_normalization/beta', # [32]),
64
+ model._blocks[0]._bn1.weight: 'blocks_0/tpu_batch_normalization/gamma', # [32]),
65
+ model._blocks[0]._bn1.running_mean: 'blocks_0/tpu_batch_normalization/moving_mean',
66
+ model._blocks[0]._bn1.running_var: 'blocks_0/tpu_batch_normalization/moving_variance',
67
+ model._blocks[0]._bn2.bias: 'blocks_0/tpu_batch_normalization_1/beta', # [16]),
68
+ model._blocks[0]._bn2.weight: 'blocks_0/tpu_batch_normalization_1/gamma', # [16]),
69
+ model._blocks[0]._bn2.running_mean: 'blocks_0/tpu_batch_normalization_1/moving_mean',
70
+ model._blocks[0]._bn2.running_var: 'blocks_0/tpu_batch_normalization_1/moving_variance',
71
+ }
72
+ conversion_table = merge(conversion_table, conversion_table_for_first_block)
73
+
74
+ # Conv blocks
75
+ for i in range(len(model._blocks)):
76
+
77
+ is_first_block = '_expand_conv.weight' not in [n for n, p in model._blocks[i].named_parameters()]
78
+
79
+ if is_first_block:
80
+ conversion_table_block = {
81
+ model._blocks[i]._project_conv.weight: 'blocks_' + str(i) + '/conv2d/kernel', # 1, 1, 32, 16]),
82
+ model._blocks[i]._depthwise_conv.weight: 'blocks_' + str(i) + '/depthwise_conv2d/depthwise_kernel',
83
+ # [3, 3, 32, 1]),
84
+ model._blocks[i]._se_reduce.bias: 'blocks_' + str(i) + '/se/conv2d/bias', # , [8]),
85
+ model._blocks[i]._se_reduce.weight: 'blocks_' + str(i) + '/se/conv2d/kernel', # , [1, 1, 32, 8]),
86
+ model._blocks[i]._se_expand.bias: 'blocks_' + str(i) + '/se/conv2d_1/bias', # , [32]),
87
+ model._blocks[i]._se_expand.weight: 'blocks_' + str(i) + '/se/conv2d_1/kernel', # , [1, 1, 8, 32]),
88
+ model._blocks[i]._bn1.bias: 'blocks_' + str(i) + '/tpu_batch_normalization/beta', # [32]),
89
+ model._blocks[i]._bn1.weight: 'blocks_' + str(i) + '/tpu_batch_normalization/gamma', # [32]),
90
+ model._blocks[i]._bn1.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_mean',
91
+ model._blocks[i]._bn1.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_variance',
92
+ model._blocks[i]._bn2.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_1/beta', # [16]),
93
+ model._blocks[i]._bn2.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_1/gamma', # [16]),
94
+ model._blocks[i]._bn2.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_mean',
95
+ model._blocks[i]._bn2.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_variance',
96
+ }
97
+
98
+ else:
99
+ conversion_table_block = {
100
+ model._blocks[i]._expand_conv.weight: 'blocks_' + str(i) + '/conv2d/kernel',
101
+ model._blocks[i]._project_conv.weight: 'blocks_' + str(i) + '/conv2d_1/kernel',
102
+ model._blocks[i]._depthwise_conv.weight: 'blocks_' + str(i) + '/depthwise_conv2d/depthwise_kernel',
103
+ model._blocks[i]._se_reduce.bias: 'blocks_' + str(i) + '/se/conv2d/bias',
104
+ model._blocks[i]._se_reduce.weight: 'blocks_' + str(i) + '/se/conv2d/kernel',
105
+ model._blocks[i]._se_expand.bias: 'blocks_' + str(i) + '/se/conv2d_1/bias',
106
+ model._blocks[i]._se_expand.weight: 'blocks_' + str(i) + '/se/conv2d_1/kernel',
107
+ model._blocks[i]._bn0.bias: 'blocks_' + str(i) + '/tpu_batch_normalization/beta',
108
+ model._blocks[i]._bn0.weight: 'blocks_' + str(i) + '/tpu_batch_normalization/gamma',
109
+ model._blocks[i]._bn0.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_mean',
110
+ model._blocks[i]._bn0.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization/moving_variance',
111
+ model._blocks[i]._bn1.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_1/beta',
112
+ model._blocks[i]._bn1.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_1/gamma',
113
+ model._blocks[i]._bn1.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_mean',
114
+ model._blocks[i]._bn1.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_1/moving_variance',
115
+ model._blocks[i]._bn2.bias: 'blocks_' + str(i) + '/tpu_batch_normalization_2/beta',
116
+ model._blocks[i]._bn2.weight: 'blocks_' + str(i) + '/tpu_batch_normalization_2/gamma',
117
+ model._blocks[i]._bn2.running_mean: 'blocks_' + str(i) + '/tpu_batch_normalization_2/moving_mean',
118
+ model._blocks[i]._bn2.running_var: 'blocks_' + str(i) + '/tpu_batch_normalization_2/moving_variance',
119
+ }
120
+
121
+ conversion_table = merge(conversion_table, conversion_table_block)
122
+
123
+ # Load TensorFlow parameters into PyTorch model
124
+ load_param(checkpoint_file, conversion_table, model_name)
125
+ return conversion_table
126
+
127
+
128
+ def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img='../../example/img.jpg'):
129
+ """ Loads and saves a TensorFlow model. """
130
+ image_files = [example_img]
131
+ eval_ckpt_driver = eval_ckpt_main.EvalCkptDriver(model_name)
132
+ with tf.Graph().as_default(), tf.Session() as sess:
133
+ images, labels = eval_ckpt_driver.build_dataset(image_files, [0] * len(image_files), False)
134
+ probs = eval_ckpt_driver.build_model(images, is_training=False)
135
+ sess.run(tf.global_variables_initializer())
136
+ print(model_ckpt)
137
+ eval_ckpt_driver.restore_model(sess, model_ckpt)
138
+ tf.train.Saver().save(sess, 'tmp/model.ckpt')
139
+
140
+
141
+ if __name__ == '__main__':
142
+
143
+ import sys
144
+ import argparse
145
+
146
+ sys.path.append('original_tf')
147
+ import eval_ckpt_main
148
+
149
+ from efficientnet_pytorch import EfficientNet
150
+
151
+ parser = argparse.ArgumentParser(
152
+ description='Convert TF model to PyTorch model and save for easier future loading')
153
+ parser.add_argument('--model_name', type=str, default='efficientnet-b0',
154
+ help='efficientnet-b{N}, where N is an integer 0 <= N <= 8')
155
+ parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/',
156
+ help='checkpoint file path')
157
+ parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth',
158
+ help='output PyTorch model file name')
159
+ args = parser.parse_args()
160
+
161
+ # Build model
162
+ model = EfficientNet.from_name(args.model_name)
163
+
164
+ # Load and save temporary TensorFlow file due to TF nuances
165
+ print(args.tf_checkpoint)
166
+ load_and_save_temporary_tensorflow_model(args.model_name, args.tf_checkpoint)
167
+
168
+ # Load weights
169
+ load_efficientnet(model, 'tmp/model.ckpt', model_name=args.model_name)
170
+ print('Loaded TF checkpoint weights')
171
+
172
+ # Save PyTorch file
173
+ torch.save(model.state_dict(), args.output_file)
174
+ print('Saved model to', args.output_file)
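The kernel-layout change inside `load_param` is the detail most likely to trip people up, so here is a tiny standalone illustration of the transpose applied to ordinary convolution kernels (the shapes are hypothetical):

import numpy as np
tf_kernel = np.zeros((3, 3, 64, 128), dtype=np.float32)  # TF layout: (kh, kw, in, out)
pt_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))        # PyTorch layout: (out, in, kh, kw)
print(pt_kernel.shape)  # (128, 64, 3, 3)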
external_data/original_tf/__init__.py ADDED
File without changes
external_data/original_tf/efficientnet_builder.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Model Builder for EfficientNet."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import functools
22
+ import os
23
+ import re
24
+ from absl import logging
25
+ import numpy as np
26
+ import six
27
+ import tensorflow.compat.v1 as tf
28
+
29
+ import efficientnet_model
30
+ import utils
31
+ MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
32
+ STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
33
+
34
+
35
+ def efficientnet_params(model_name):
36
+ """Get efficientnet params based on model name."""
37
+ params_dict = {
38
+ # (width_coefficient, depth_coefficient, resolution, dropout_rate)
39
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
40
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
41
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
42
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
43
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
44
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
45
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
46
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
47
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
48
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
49
+ }
50
+ return params_dict[model_name]
51
+
52
+
53
+ class BlockDecoder(object):
54
+ """Block Decoder for readability."""
55
+
56
+ def _decode_block_string(self, block_string):
57
+ """Gets a block through a string notation of arguments."""
58
+ if six.PY2:
59
+ assert isinstance(block_string, (str, unicode))
60
+ else:
61
+ assert isinstance(block_string, str)
62
+ ops = block_string.split('_')
63
+ options = {}
64
+ for op in ops:
65
+ splits = re.split(r'(\d.*)', op)
66
+ if len(splits) >= 2:
67
+ key, value = splits[:2]
68
+ options[key] = value
69
+
70
+ if 's' not in options or len(options['s']) != 2:
71
+ raise ValueError('Strides options should be a pair of integers.')
72
+
73
+ return efficientnet_model.BlockArgs(
74
+ kernel_size=int(options['k']),
75
+ num_repeat=int(options['r']),
76
+ input_filters=int(options['i']),
77
+ output_filters=int(options['o']),
78
+ expand_ratio=int(options['e']),
79
+ id_skip=('noskip' not in block_string),
80
+ se_ratio=float(options['se']) if 'se' in options else None,
81
+ strides=[int(options['s'][0]),
82
+ int(options['s'][1])],
83
+ conv_type=int(options['c']) if 'c' in options else 0,
84
+ fused_conv=int(options['f']) if 'f' in options else 0,
85
+ super_pixel=int(options['p']) if 'p' in options else 0,
86
+ condconv=('cc' in block_string))
87
+
88
+ def _encode_block_string(self, block):
89
+ """Encodes a block to a string."""
90
+ args = [
91
+ 'r%d' % block.num_repeat,
92
+ 'k%d' % block.kernel_size,
93
+ 's%d%d' % (block.strides[0], block.strides[1]),
94
+ 'e%s' % block.expand_ratio,
95
+ 'i%d' % block.input_filters,
96
+ 'o%d' % block.output_filters,
97
+ 'c%d' % block.conv_type,
98
+ 'f%d' % block.fused_conv,
99
+ 'p%d' % block.super_pixel,
100
+ ]
101
+ if block.se_ratio > 0 and block.se_ratio <= 1:
102
+ args.append('se%s' % block.se_ratio)
103
+ if block.id_skip is False: # pylint: disable=g-bool-id-comparison
104
+ args.append('noskip')
105
+ if block.condconv:
106
+ args.append('cc')
107
+ return '_'.join(args)
108
+
109
+ def decode(self, string_list):
110
+ """Decodes a list of string notations to specify blocks inside the network.
111
+
112
+ Args:
113
+ string_list: a list of strings, each string is a notation of block.
114
+
115
+ Returns:
116
+ A list of namedtuples to represent blocks arguments.
117
+ """
118
+ assert isinstance(string_list, list)
119
+ blocks_args = []
120
+ for block_string in string_list:
121
+ blocks_args.append(self._decode_block_string(block_string))
122
+ return blocks_args
123
+
124
+ def encode(self, blocks_args):
125
+ """Encodes a list of Blocks to a list of strings.
126
+
127
+ Args:
128
+ blocks_args: A list of namedtuples to represent blocks arguments.
129
+ Returns:
130
+ a list of strings, each string is a notation of block.
131
+ """
132
+ block_strings = []
133
+ for block in blocks_args:
134
+ block_strings.append(self._encode_block_string(block))
135
+ return block_strings
136
+
137
+
138
+ def swish(features, use_native=True, use_hard=False):
139
+ """Computes the Swish activation function.
140
+
141
+ We provide three alternatives:
142
+ - Native tf.nn.swish, which uses less memory during training than the composable swish.
143
+ - Quantization friendly hard swish.
144
+ - A composable swish, equivalent to tf.nn.swish, but more general for
145
+ finetuning and TF-Hub.
146
+
147
+ Args:
148
+ features: A `Tensor` representing preactivation values.
149
+ use_native: Whether to use the native swish from tf.nn that uses a custom
150
+ gradient to reduce memory usage, or to use customized swish that uses
151
+ default TensorFlow gradient computation.
152
+ use_hard: Whether to use quantization-friendly hard swish.
153
+
154
+ Returns:
155
+ The activation value.
156
+ """
157
+ if use_native and use_hard:
158
+ raise ValueError('Cannot specify both use_native and use_hard.')
159
+
160
+ if use_native:
161
+ return tf.nn.swish(features)
162
+
163
+ if use_hard:
164
+ return features * tf.nn.relu6(features + np.float32(3)) * (1. / 6.)
165
+
166
+ features = tf.convert_to_tensor(features, name='features')
167
+ return features * tf.nn.sigmoid(features)
168
+
169
+
170
+ _DEFAULT_BLOCKS_ARGS = [
171
+ 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
172
+ 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
173
+ 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
174
+ 'r1_k3_s11_e6_i192_o320_se0.25',
175
+ ]
176
+
177
+
178
+ def efficientnet(width_coefficient=None,
179
+ depth_coefficient=None,
180
+ dropout_rate=0.2,
181
+ survival_prob=0.8):
182
+ """Creates a efficientnet model."""
183
+ global_params = efficientnet_model.GlobalParams(
184
+ blocks_args=_DEFAULT_BLOCKS_ARGS,
185
+ batch_norm_momentum=0.99,
186
+ batch_norm_epsilon=1e-3,
187
+ dropout_rate=dropout_rate,
188
+ survival_prob=survival_prob,
189
+ data_format='channels_last',
190
+ num_classes=1000,
191
+ width_coefficient=width_coefficient,
192
+ depth_coefficient=depth_coefficient,
193
+ depth_divisor=8,
194
+ min_depth=None,
195
+ relu_fn=tf.nn.swish,
196
+ # The default is TPU-specific batch norm.
197
+ # The alternative is tf.layers.BatchNormalization.
198
+ batch_norm=utils.TpuBatchNormalization, # TPU-specific requirement.
199
+ use_se=True,
200
+ clip_projection_output=False)
201
+ return global_params
202
+
203
+
204
+ def get_model_params(model_name, override_params):
205
+ """Get the block args and global params for a given model."""
206
+ if model_name.startswith('efficientnet'):
207
+ width_coefficient, depth_coefficient, _, dropout_rate = (
208
+ efficientnet_params(model_name))
209
+ global_params = efficientnet(
210
+ width_coefficient, depth_coefficient, dropout_rate)
211
+ else:
212
+ raise NotImplementedError('model name is not pre-defined: %s' % model_name)
213
+
214
+ if override_params:
215
+ # ValueError will be raised here if override_params has fields not included
216
+ # in global_params.
217
+ global_params = global_params._replace(**override_params)
218
+
219
+ decoder = BlockDecoder()
220
+ blocks_args = decoder.decode(global_params.blocks_args)
221
+
222
+ logging.info('global_params= %s', global_params)
223
+ return blocks_args, global_params
224
+
225
+
226
+ def build_model(images,
227
+ model_name,
228
+ training,
229
+ override_params=None,
230
+ model_dir=None,
231
+ fine_tuning=False,
232
+ features_only=False,
233
+ pooled_features_only=False):
234
+ """A helper functiion to creates a model and returns predicted logits.
235
+
236
+ Args:
237
+ images: input images tensor.
238
+ model_name: string, the predefined model name.
239
+ training: boolean, whether the model is constructed for training.
240
+ override_params: A dictionary of params for overriding. Fields must exist in
241
+ efficientnet_model.GlobalParams.
242
+ model_dir: string, optional model dir for saving configs.
243
+ fine_tuning: boolean, whether the model is used for finetuning.
244
+ features_only: build the base feature network only (excluding final
245
+ 1x1 conv layer, global pooling, dropout and fc head).
246
+ pooled_features_only: build the base network for features extraction (after
247
+ 1x1 conv layer and global pooling, but before dropout and fc head).
248
+
249
+ Returns:
250
+ logits: the logits tensor of classes.
251
+ endpoints: the endpoints for each layer.
252
+
253
+ Raises:
254
+ When model_name specified an undefined model, raises NotImplementedError.
255
+ When override_params has invalid fields, raises ValueError.
256
+ """
257
+ assert isinstance(images, tf.Tensor)
258
+ assert not (features_only and pooled_features_only)
259
+
260
+ # For backward compatibility.
261
+ if override_params and override_params.get('drop_connect_rate', None):
262
+ override_params['survival_prob'] = 1 - override_params['drop_connect_rate']
263
+
264
+ if not training or fine_tuning:
265
+ if not override_params:
266
+ override_params = {}
267
+ override_params['batch_norm'] = utils.BatchNormalization
268
+ if fine_tuning:
269
+ override_params['relu_fn'] = functools.partial(swish, use_native=False)
270
+ blocks_args, global_params = get_model_params(model_name, override_params)
271
+
272
+ if model_dir:
273
+ param_file = os.path.join(model_dir, 'model_params.txt')
274
+ if not tf.gfile.Exists(param_file):
275
+ if not tf.gfile.Exists(model_dir):
276
+ tf.gfile.MakeDirs(model_dir)
277
+ with tf.gfile.GFile(param_file, 'w') as f:
278
+ logging.info('writing to %s', param_file)
279
+ f.write('model_name= %s\n\n' % model_name)
280
+ f.write('global_params= %s\n\n' % str(global_params))
281
+ f.write('blocks_args= %s\n\n' % str(blocks_args))
282
+
283
+ with tf.variable_scope(model_name):
284
+ model = efficientnet_model.Model(blocks_args, global_params)
285
+ outputs = model(
286
+ images,
287
+ training=training,
288
+ features_only=features_only,
289
+ pooled_features_only=pooled_features_only)
290
+ if features_only:
291
+ outputs = tf.identity(outputs, 'features')
292
+ elif pooled_features_only:
293
+ outputs = tf.identity(outputs, 'pooled_features')
294
+ else:
295
+ outputs = tf.identity(outputs, 'logits')
296
+ return outputs, model.endpoints
297
+
298
+
299
+ def build_model_base(images, model_name, training, override_params=None):
300
+ """A helper functiion to create a base model and return global_pool.
301
+
302
+ Args:
303
+ images: input images tensor.
304
+ model_name: string, the predefined model name.
305
+ training: boolean, whether the model is constructed for training.
306
+ override_params: A dictionary of params for overriding. Fields must exist in
307
+ efficientnet_model.GlobalParams.
308
+
309
+ Returns:
310
+ features: global pool features.
311
+ endpoints: the endpoints for each layer.
312
+
313
+ Raises:
314
+ When model_name specified an undefined model, raises NotImplementedError.
315
+ When override_params has invalid fields, raises ValueError.
316
+ """
317
+ assert isinstance(images, tf.Tensor)
318
+ # For backward compatibility.
319
+ if override_params and override_params.get('drop_connect_rate', None):
320
+ override_params['survival_prob'] = 1 - override_params['drop_connect_rate']
321
+
322
+ blocks_args, global_params = get_model_params(model_name, override_params)
323
+
324
+ with tf.variable_scope(model_name):
325
+ model = efficientnet_model.Model(blocks_args, global_params)
326
+ features = model(images, training=training, features_only=True)
327
+
328
+ features = tf.identity(features, 'features')
329
+ return features, model.endpoints
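A short sketch of how the pieces above fit together, assuming this module is importable: `efficientnet_params` looks up the compound-scaling coefficients, and `BlockDecoder` turns one of the default block strings into a `BlockArgs` namedtuple:

width, depth, resolution, dropout = efficientnet_params('efficientnet-b7')
print(width, depth, resolution, dropout)  # 2.0 3.1 600 0.5

decoder = BlockDecoder()
block = decoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
print(block.kernel_size, block.strides, block.se_ratio)  # 3 [1, 1] 0.25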
external_data/original_tf/efficientnet_model.py ADDED
@@ -0,0 +1,713 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Contains definitions for EfficientNet model.
16
+
17
+ [1] Mingxing Tan, Quoc V. Le
18
+ EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
19
+ ICML'19, https://arxiv.org/abs/1905.11946
20
+ """
21
+
22
+ from __future__ import absolute_import
23
+ from __future__ import division
24
+ from __future__ import print_function
25
+
26
+ import collections
27
+ import functools
28
+ import math
29
+
30
+ from absl import logging
31
+ import numpy as np
32
+ import six
33
+ from six.moves import xrange
34
+ import tensorflow.compat.v1 as tf
35
+
36
+ import utils
37
+ # from condconv import condconv_layers
38
+
39
+ GlobalParams = collections.namedtuple('GlobalParams', [
40
+ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
41
+ 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
42
+ 'min_depth', 'survival_prob', 'relu_fn', 'batch_norm', 'use_se',
43
+ 'local_pooling', 'condconv_num_experts', 'clip_projection_output',
44
+ 'blocks_args'
45
+ ])
46
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
47
+
48
+ BlockArgs = collections.namedtuple('BlockArgs', [
49
+ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
50
+ 'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'conv_type', 'fused_conv',
51
+ 'super_pixel', 'condconv'
52
+ ])
53
+ # defaults will be a public argument for namedtuple in Python 3.7
54
+ # https://docs.python.org/3/library/collections.html#collections.namedtuple
55
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
56
+
57
+
58
+ def conv_kernel_initializer(shape, dtype=None, partition_info=None):
59
+ """Initialization for convolutional kernels.
60
+
61
+ The main difference with tf.variance_scaling_initializer is that
62
+ tf.variance_scaling_initializer uses a truncated normal with an uncorrected
63
+ standard deviation, whereas here we use a normal distribution. Similarly,
64
+ tf.initializers.variance_scaling uses a truncated normal with
65
+ a corrected standard deviation.
66
+
67
+ Args:
68
+ shape: shape of variable
69
+ dtype: dtype of variable
70
+ partition_info: unused
71
+
72
+ Returns:
73
+ an initialization for the variable
74
+ """
75
+ del partition_info
76
+ kernel_height, kernel_width, _, out_filters = shape
77
+ fan_out = int(kernel_height * kernel_width * out_filters)
78
+ return tf.random_normal(
79
+ shape, mean=0.0, stddev=np.sqrt(2.0 / fan_out), dtype=dtype)
80
+
81
+
82
+ def dense_kernel_initializer(shape, dtype=None, partition_info=None):
83
+ """Initialization for dense kernels.
84
+
85
+ This initialization is equal to
86
+ tf.variance_scaling_initializer(scale=1.0/3.0, mode='fan_out',
87
+ distribution='uniform').
88
+ It is written out explicitly here for clarity.
89
+
90
+ Args:
91
+ shape: shape of variable
92
+ dtype: dtype of variable
93
+ partition_info: unused
94
+
95
+ Returns:
96
+ an initialization for the variable
97
+ """
98
+ del partition_info
99
+ init_range = 1.0 / np.sqrt(shape[1])
100
+ return tf.random_uniform(shape, -init_range, init_range, dtype=dtype)
101
+
102
+
103
+ def superpixel_kernel_initializer(shape, dtype='float32', partition_info=None):
104
+ """Initializes superpixel kernels.
105
+
106
+ This is inspired by space-to-depth transformation that is mathematically
107
+ equivalent before and after the transformation. But we do the space-to-depth
108
+ via a convolution. Moreover, because the layer is trainable rather than a
109
+ fixed transform, we can initialize it this way so that the model starts out
110
+ mathematically equivalent and only learns to deviate when doing so improves
111
+ performance.
112
+
113
+
114
+ Args:
115
+ shape: shape of variable
116
+ dtype: dtype of variable
117
+ partition_info: unused
118
+
119
+ Returns:
120
+ an initialization for the variable
121
+ """
122
+ del partition_info
123
+ # use input depth to make superpixel kernel.
124
+ depth = shape[-2]
125
+ filters = np.zeros([2, 2, depth, 4 * depth], dtype=dtype)
126
+ i = np.arange(2)
127
+ j = np.arange(2)
128
+ k = np.arange(depth)
129
+ mesh = np.array(np.meshgrid(i, j, k)).T.reshape(-1, 3).T
130
+ filters[
131
+ mesh[0],
132
+ mesh[1],
133
+ mesh[2],
134
+ 4 * mesh[2] + 2 * mesh[0] + mesh[1]] = 1
135
+ return filters
136
+
137
+
138
+ def round_filters(filters, global_params):
139
+ """Round number of filters based on depth multiplier."""
140
+ orig_f = filters
141
+ multiplier = global_params.width_coefficient
142
+ divisor = global_params.depth_divisor
143
+ min_depth = global_params.min_depth
144
+ if not multiplier:
145
+ return filters
146
+
147
+ filters *= multiplier
148
+ min_depth = min_depth or divisor
149
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
150
+ # Make sure that round down does not go down by more than 10%.
151
+ if new_filters < 0.9 * filters:
152
+ new_filters += divisor
153
+ logging.info('round_filter input=%s output=%s', orig_f, new_filters)
154
+ return int(new_filters)
155
+
156
+
157
+ def round_repeats(repeats, global_params):
158
+ """Round number of filters based on depth multiplier."""
159
+ multiplier = global_params.depth_coefficient
160
+ if not multiplier:
161
+ return repeats
162
+ return int(math.ceil(multiplier * repeats))
163
+
164
+
165
+ class MBConvBlock(tf.keras.layers.Layer):
166
+ """A class of MBConv: Mobile Inverted Residual Bottleneck.
167
+
168
+ Attributes:
169
+ endpoints: dict. A list of internal tensors.
170
+ """
171
+
172
+ def __init__(self, block_args, global_params):
173
+ """Initializes a MBConv block.
174
+
175
+ Args:
176
+ block_args: BlockArgs, arguments to create a Block.
177
+ global_params: GlobalParams, a set of global parameters.
178
+ """
179
+ super(MBConvBlock, self).__init__()
180
+ self._block_args = block_args
181
+ self._batch_norm_momentum = global_params.batch_norm_momentum
182
+ self._batch_norm_epsilon = global_params.batch_norm_epsilon
183
+ self._batch_norm = global_params.batch_norm
184
+ self._condconv_num_experts = global_params.condconv_num_experts
185
+ self._data_format = global_params.data_format
186
+ if self._data_format == 'channels_first':
187
+ self._channel_axis = 1
188
+ self._spatial_dims = [2, 3]
189
+ else:
190
+ self._channel_axis = -1
191
+ self._spatial_dims = [1, 2]
192
+
193
+ self._relu_fn = global_params.relu_fn or tf.nn.swish
194
+ self._has_se = (
195
+ global_params.use_se and self._block_args.se_ratio is not None and
196
+ 0 < self._block_args.se_ratio <= 1)
197
+
198
+ self._clip_projection_output = global_params.clip_projection_output
199
+
200
+ self.endpoints = None
201
+
202
+ self.conv_cls = tf.layers.Conv2D
203
+ self.depthwise_conv_cls = utils.DepthwiseConv2D
204
+ if self._block_args.condconv:
205
+ self.conv_cls = functools.partial(
206
+ condconv_layers.CondConv2D, num_experts=self._condconv_num_experts)
207
+ self.depthwise_conv_cls = functools.partial(
208
+ condconv_layers.DepthwiseCondConv2D,
209
+ num_experts=self._condconv_num_experts)
210
+
211
+ # Builds the block according to the arguments.
212
+ self._build()
213
+
214
+ def block_args(self):
215
+ return self._block_args
216
+
217
+ def _build(self):
218
+ """Builds block according to the arguments."""
219
+ if self._block_args.super_pixel == 1:
220
+ self._superpixel = tf.layers.Conv2D(
221
+ self._block_args.input_filters,
222
+ kernel_size=[2, 2],
223
+ strides=[2, 2],
224
+ kernel_initializer=conv_kernel_initializer,
225
+ padding='same',
226
+ data_format=self._data_format,
227
+ use_bias=False)
228
+ self._bnsp = self._batch_norm(
229
+ axis=self._channel_axis,
230
+ momentum=self._batch_norm_momentum,
231
+ epsilon=self._batch_norm_epsilon)
232
+
233
+ if self._block_args.condconv:
234
+ # Add the example-dependent routing function
235
+ self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
236
+ data_format=self._data_format)
237
+ self._routing_fn = tf.layers.Dense(
238
+ self._condconv_num_experts, activation=tf.nn.sigmoid)
239
+
240
+ filters = self._block_args.input_filters * self._block_args.expand_ratio
241
+ kernel_size = self._block_args.kernel_size
242
+
243
+ # Fused expansion phase. Called if using fused convolutions.
244
+ self._fused_conv = self.conv_cls(
245
+ filters=filters,
246
+ kernel_size=[kernel_size, kernel_size],
247
+ strides=self._block_args.strides,
248
+ kernel_initializer=conv_kernel_initializer,
249
+ padding='same',
250
+ data_format=self._data_format,
251
+ use_bias=False)
252
+
253
+ # Expansion phase. Called if not using fused convolutions and expansion
254
+ # phase is necessary.
255
+ self._expand_conv = self.conv_cls(
256
+ filters=filters,
257
+ kernel_size=[1, 1],
258
+ strides=[1, 1],
259
+ kernel_initializer=conv_kernel_initializer,
260
+ padding='same',
261
+ data_format=self._data_format,
262
+ use_bias=False)
263
+ self._bn0 = self._batch_norm(
264
+ axis=self._channel_axis,
265
+ momentum=self._batch_norm_momentum,
266
+ epsilon=self._batch_norm_epsilon)
267
+
268
+ # Depth-wise convolution phase. Called if not using fused convolutions.
269
+ self._depthwise_conv = self.depthwise_conv_cls(
270
+ kernel_size=[kernel_size, kernel_size],
271
+ strides=self._block_args.strides,
272
+ depthwise_initializer=conv_kernel_initializer,
273
+ padding='same',
274
+ data_format=self._data_format,
275
+ use_bias=False)
276
+
277
+ self._bn1 = self._batch_norm(
278
+ axis=self._channel_axis,
279
+ momentum=self._batch_norm_momentum,
280
+ epsilon=self._batch_norm_epsilon)
281
+
282
+ if self._has_se:
283
+ num_reduced_filters = max(
284
+ 1, int(self._block_args.input_filters * self._block_args.se_ratio))
285
+ # Squeeze and Excitation layer.
286
+ self._se_reduce = tf.layers.Conv2D(
287
+ num_reduced_filters,
288
+ kernel_size=[1, 1],
289
+ strides=[1, 1],
290
+ kernel_initializer=conv_kernel_initializer,
291
+ padding='same',
292
+ data_format=self._data_format,
293
+ use_bias=True)
294
+ self._se_expand = tf.layers.Conv2D(
295
+ filters,
296
+ kernel_size=[1, 1],
297
+ strides=[1, 1],
298
+ kernel_initializer=conv_kernel_initializer,
299
+ padding='same',
300
+ data_format=self._data_format,
301
+ use_bias=True)
302
+
303
+ # Output phase.
304
+ filters = self._block_args.output_filters
305
+ self._project_conv = self.conv_cls(
306
+ filters=filters,
307
+ kernel_size=[1, 1],
308
+ strides=[1, 1],
309
+ kernel_initializer=conv_kernel_initializer,
310
+ padding='same',
311
+ data_format=self._data_format,
312
+ use_bias=False)
313
+ self._bn2 = self._batch_norm(
314
+ axis=self._channel_axis,
315
+ momentum=self._batch_norm_momentum,
316
+ epsilon=self._batch_norm_epsilon)
317
+
318
+ def _call_se(self, input_tensor):
319
+ """Call Squeeze and Excitation layer.
320
+
321
+ Args:
322
+ input_tensor: Tensor, a single input tensor for Squeeze/Excitation layer.
323
+
324
+ Returns:
325
+ An output tensor, which should have the same shape as the input.
326
+ """
327
+ se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True)
328
+ se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor)))
329
+ logging.info('Built Squeeze and Excitation with tensor shape: %s',
330
+ (se_tensor.shape))
331
+ return tf.sigmoid(se_tensor) * input_tensor
332
+
333
+ def call(self, inputs, training=True, survival_prob=None):
334
+ """Implementation of call().
335
+
336
+ Args:
337
+ inputs: the inputs tensor.
338
+ training: boolean, whether the model is constructed for training.
339
+ survival_prob: float, between 0 to 1, drop connect rate.
340
+
341
+ Returns:
342
+ An output tensor.
343
+ """
344
+ logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
345
+ logging.info('Block input depth: %s output depth: %s',
346
+ self._block_args.input_filters,
347
+ self._block_args.output_filters)
348
+
349
+ x = inputs
350
+
351
+ fused_conv_fn = self._fused_conv
352
+ expand_conv_fn = self._expand_conv
353
+ depthwise_conv_fn = self._depthwise_conv
354
+ project_conv_fn = self._project_conv
355
+
356
+ if self._block_args.condconv:
357
+ pooled_inputs = self._avg_pooling(inputs)
358
+ routing_weights = self._routing_fn(pooled_inputs)
359
+ # Capture routing weights as additional input to CondConv layers
360
+ fused_conv_fn = functools.partial(
361
+ self._fused_conv, routing_weights=routing_weights)
362
+ expand_conv_fn = functools.partial(
363
+ self._expand_conv, routing_weights=routing_weights)
364
+ depthwise_conv_fn = functools.partial(
365
+ self._depthwise_conv, routing_weights=routing_weights)
366
+ project_conv_fn = functools.partial(
367
+ self._project_conv, routing_weights=routing_weights)
368
+
369
+ # applies the 2x2 super-pixel (space-to-depth) convolution first
370
+ if self._block_args.super_pixel == 1:
371
+ with tf.variable_scope('super_pixel'):
372
+ x = self._relu_fn(
373
+ self._bnsp(self._superpixel(x), training=training))
374
+ logging.info(
375
+ 'Block start with SuperPixel: %s shape: %s', x.name, x.shape)
376
+
377
+ if self._block_args.fused_conv:
378
+ # If use fused mbconv, skip expansion and use regular conv.
379
+ x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training))
380
+ logging.info('Conv2D: %s shape: %s', x.name, x.shape)
381
+ else:
382
+ # Otherwise, first apply expansion and then apply depthwise conv.
383
+ if self._block_args.expand_ratio != 1:
384
+ x = self._relu_fn(self._bn0(expand_conv_fn(x), training=training))
385
+ logging.info('Expand: %s shape: %s', x.name, x.shape)
386
+
387
+ x = self._relu_fn(self._bn1(depthwise_conv_fn(x), training=training))
388
+ logging.info('DWConv: %s shape: %s', x.name, x.shape)
389
+
390
+ if self._has_se:
391
+ with tf.variable_scope('se'):
392
+ x = self._call_se(x)
393
+
394
+ self.endpoints = {'expansion_output': x}
395
+
396
+ x = self._bn2(project_conv_fn(x), training=training)
397
+ # Add identity so that quantization-aware training can insert quantization
398
+ # ops correctly.
399
+ x = tf.identity(x)
400
+ if self._clip_projection_output:
401
+ x = tf.clip_by_value(x, -6, 6)
402
+ if self._block_args.id_skip:
403
+ if all(
404
+ s == 1 for s in self._block_args.strides
405
+ ) and self._block_args.input_filters == self._block_args.output_filters:
406
+ # Apply only if the skip connection is present.
407
+ if survival_prob:
408
+ x = utils.drop_connect(x, training, survival_prob)
409
+ x = tf.add(x, inputs)
410
+ logging.info('Project: %s shape: %s', x.name, x.shape)
411
+ return x
412
+
413
+
414
+ class MBConvBlockWithoutDepthwise(MBConvBlock):
415
+ """MBConv-like block without depthwise convolution and squeeze-and-excite."""
416
+
417
+ def _build(self):
418
+ """Builds block according to the arguments."""
419
+ filters = self._block_args.input_filters * self._block_args.expand_ratio
420
+ if self._block_args.expand_ratio != 1:
421
+ # Expansion phase:
422
+ self._expand_conv = tf.layers.Conv2D(
423
+ filters,
424
+ kernel_size=[3, 3],
425
+ strides=[1, 1],
426
+ kernel_initializer=conv_kernel_initializer,
427
+ padding='same',
428
+ use_bias=False)
429
+ self._bn0 = self._batch_norm(
430
+ axis=self._channel_axis,
431
+ momentum=self._batch_norm_momentum,
432
+ epsilon=self._batch_norm_epsilon)
433
+
434
+ # Output phase:
435
+ filters = self._block_args.output_filters
436
+ self._project_conv = tf.layers.Conv2D(
437
+ filters,
438
+ kernel_size=[1, 1],
439
+ strides=self._block_args.strides,
440
+ kernel_initializer=conv_kernel_initializer,
441
+ padding='same',
442
+ use_bias=False)
443
+ self._bn1 = self._batch_norm(
444
+ axis=self._channel_axis,
445
+ momentum=self._batch_norm_momentum,
446
+ epsilon=self._batch_norm_epsilon)
447
+
448
+ def call(self, inputs, training=True, survival_prob=None):
449
+ """Implementation of call().
450
+
451
+ Args:
452
+ inputs: the inputs tensor.
453
+ training: boolean, whether the model is constructed for training.
454
+ survival_prob: float, between 0 to 1, drop connect rate.
455
+
456
+ Returns:
457
+ An output tensor.
458
+ """
459
+ logging.info('Block input: %s shape: %s', inputs.name, inputs.shape)
460
+ if self._block_args.expand_ratio != 1:
461
+ x = self._relu_fn(self._bn0(self._expand_conv(inputs), training=training))
462
+ else:
463
+ x = inputs
464
+ logging.info('Expand: %s shape: %s', x.name, x.shape)
465
+
466
+ self.endpoints = {'expansion_output': x}
467
+
468
+ x = self._bn1(self._project_conv(x), training=training)
469
+ # Add identity so that quantization-aware training can insert quantization
470
+ # ops correctly.
471
+ x = tf.identity(x)
472
+ if self._clip_projection_output:
473
+ x = tf.clip_by_value(x, -6, 6)
474
+
475
+ if self._block_args.id_skip:
476
+ if all(
477
+ s == 1 for s in self._block_args.strides
478
+ ) and self._block_args.input_filters == self._block_args.output_filters:
479
+ # Apply only if the skip connection is present.
480
+ if survival_prob:
481
+ x = utils.drop_connect(x, training, survival_prob)
482
+ x = tf.add(x, inputs)
483
+ logging.info('Project: %s shape: %s', x.name, x.shape)
484
+ return x
485
+
486
+
487
+ class Model(tf.keras.Model):
488
+ """A class implements tf.keras.Model for MNAS-like model.
489
+
490
+ Reference: https://arxiv.org/abs/1807.11626
491
+ """
492
+
493
+ def __init__(self, blocks_args=None, global_params=None):
494
+ """Initializes an `Model` instance.
495
+
496
+ Args:
497
+ blocks_args: A list of BlockArgs to construct block modules.
498
+ global_params: GlobalParams, a set of global parameters.
499
+
500
+ Raises:
501
+ ValueError: when blocks_args is not specified as a list.
502
+ """
503
+ super(Model, self).__init__()
504
+ if not isinstance(blocks_args, list):
505
+ raise ValueError('blocks_args should be a list.')
506
+ self._global_params = global_params
507
+ self._blocks_args = blocks_args
508
+ self._relu_fn = global_params.relu_fn or tf.nn.swish
509
+ self._batch_norm = global_params.batch_norm
510
+
511
+ self.endpoints = None
512
+
513
+ self._build()
514
+
515
+ def _get_conv_block(self, conv_type):
516
+ conv_block_map = {0: MBConvBlock, 1: MBConvBlockWithoutDepthwise}
517
+ return conv_block_map[conv_type]
518
+
519
+ def _build(self):
520
+ """Builds a model."""
521
+ self._blocks = []
522
+ batch_norm_momentum = self._global_params.batch_norm_momentum
523
+ batch_norm_epsilon = self._global_params.batch_norm_epsilon
524
+ if self._global_params.data_format == 'channels_first':
525
+ channel_axis = 1
526
+ self._spatial_dims = [2, 3]
527
+ else:
528
+ channel_axis = -1
529
+ self._spatial_dims = [1, 2]
530
+
531
+ # Stem part.
532
+ self._conv_stem = tf.layers.Conv2D(
533
+ filters=round_filters(32, self._global_params),
534
+ kernel_size=[3, 3],
535
+ strides=[2, 2],
536
+ kernel_initializer=conv_kernel_initializer,
537
+ padding='same',
538
+ data_format=self._global_params.data_format,
539
+ use_bias=False)
540
+ self._bn0 = self._batch_norm(
541
+ axis=channel_axis,
542
+ momentum=batch_norm_momentum,
543
+ epsilon=batch_norm_epsilon)
544
+
545
+ # Builds blocks.
546
+ for block_args in self._blocks_args:
547
+ assert block_args.num_repeat > 0
548
+ assert block_args.super_pixel in [0, 1, 2]
549
+ # Update block input and output filters based on depth multiplier.
550
+ input_filters = round_filters(block_args.input_filters,
551
+ self._global_params)
552
+ output_filters = round_filters(block_args.output_filters,
553
+ self._global_params)
554
+ kernel_size = block_args.kernel_size
555
+ block_args = block_args._replace(
556
+ input_filters=input_filters,
557
+ output_filters=output_filters,
558
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params))
559
+
560
+ # The first block needs to take care of stride and filter size increase.
561
+ conv_block = self._get_conv_block(block_args.conv_type)
562
+ if not block_args.super_pixel: # no super_pixel at all
563
+ self._blocks.append(conv_block(block_args, self._global_params))
564
+ else:
565
+ # if superpixel, adjust filters, kernels, and strides.
566
+ depth_factor = int(4 / block_args.strides[0] / block_args.strides[1])
567
+ block_args = block_args._replace(
568
+ input_filters=block_args.input_filters * depth_factor,
569
+ output_filters=block_args.output_filters * depth_factor,
570
+ kernel_size=((block_args.kernel_size + 1) // 2 if depth_factor > 1
571
+ else block_args.kernel_size))
572
+ # if the first block has stride-2 and super_pixel transformation
573
+ if (block_args.strides[0] == 2 and block_args.strides[1] == 2):
574
+ block_args = block_args._replace(strides=[1, 1])
575
+ self._blocks.append(conv_block(block_args, self._global_params))
576
+ block_args = block_args._replace( # sp stops at stride-2
577
+ super_pixel=0,
578
+ input_filters=input_filters,
579
+ output_filters=output_filters,
580
+ kernel_size=kernel_size)
581
+ elif block_args.super_pixel == 1:
582
+ self._blocks.append(conv_block(block_args, self._global_params))
583
+ block_args = block_args._replace(super_pixel=2)
584
+ else:
585
+ self._blocks.append(conv_block(block_args, self._global_params))
586
+ if block_args.num_repeat > 1: # rest of blocks with the same block_arg
587
+ # pylint: disable=protected-access
588
+ block_args = block_args._replace(
589
+ input_filters=block_args.output_filters, strides=[1, 1])
590
+ # pylint: enable=protected-access
591
+ for _ in xrange(block_args.num_repeat - 1):
592
+ self._blocks.append(conv_block(block_args, self._global_params))
593
+
594
+ # Head part.
595
+ self._conv_head = tf.layers.Conv2D(
596
+ filters=round_filters(1280, self._global_params),
597
+ kernel_size=[1, 1],
598
+ strides=[1, 1],
599
+ kernel_initializer=conv_kernel_initializer,
600
+ padding='same',
601
+ use_bias=False)
602
+ self._bn1 = self._batch_norm(
603
+ axis=channel_axis,
604
+ momentum=batch_norm_momentum,
605
+ epsilon=batch_norm_epsilon)
606
+
607
+ self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
608
+ data_format=self._global_params.data_format)
609
+ if self._global_params.num_classes:
610
+ self._fc = tf.layers.Dense(
611
+ self._global_params.num_classes,
612
+ kernel_initializer=dense_kernel_initializer)
613
+ else:
614
+ self._fc = None
615
+
616
+ if self._global_params.dropout_rate > 0:
617
+ self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate)
618
+ else:
619
+ self._dropout = None
620
+
621
+ def call(self,
622
+ inputs,
623
+ training=True,
624
+ features_only=None,
625
+ pooled_features_only=False):
626
+ """Implementation of call().
627
+
628
+ Args:
629
+ inputs: input tensors.
630
+ training: boolean, whether the model is constructed for training.
631
+ features_only: build the base feature network only.
632
+ pooled_features_only: build the base network for features extraction
633
+ (after 1x1 conv layer and global pooling, but before dropout and fc
634
+ head).
635
+
636
+ Returns:
637
+ output tensors.
638
+ """
639
+ outputs = None
640
+ self.endpoints = {}
641
+ reduction_idx = 0
642
+ # Calls Stem layers
643
+ with tf.variable_scope('stem'):
644
+ outputs = self._relu_fn(
645
+ self._bn0(self._conv_stem(inputs), training=training))
646
+ logging.info('Built stem layers with output shape: %s', outputs.shape)
647
+ self.endpoints['stem'] = outputs
648
+
649
+ # Calls blocks.
650
+ for idx, block in enumerate(self._blocks):
651
+ is_reduction = False # reduction flag for blocks after the stem layer
652
+ # If the first block has super-pixel (space-to-depth) layer, then stem is
653
+ # the first reduction point.
654
+ if (block.block_args().super_pixel == 1 and idx == 0):
655
+ reduction_idx += 1
656
+ self.endpoints['reduction_%s' % reduction_idx] = outputs
657
+
658
+ elif ((idx == len(self._blocks) - 1) or
659
+ self._blocks[idx + 1].block_args().strides[0] > 1):
660
+ is_reduction = True
661
+ reduction_idx += 1
662
+
663
+ with tf.variable_scope('blocks_%s' % idx):
664
+ survival_prob = self._global_params.survival_prob
665
+ if survival_prob:
666
+ drop_rate = 1.0 - survival_prob
667
+ survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks)
668
+ logging.info('block_%s survival_prob: %s', idx, survival_prob)
669
+ outputs = block.call(
670
+ outputs, training=training, survival_prob=survival_prob)
671
+ self.endpoints['block_%s' % idx] = outputs
672
+ if is_reduction:
673
+ self.endpoints['reduction_%s' % reduction_idx] = outputs
674
+ if block.endpoints:
675
+ for k, v in six.iteritems(block.endpoints):
676
+ self.endpoints['block_%s/%s' % (idx, k)] = v
677
+ if is_reduction:
678
+ self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
679
+ self.endpoints['features'] = outputs
680
+
681
+ if not features_only:
682
+ # Calls final layers and returns logits.
683
+ with tf.variable_scope('head'):
684
+ outputs = self._relu_fn(
685
+ self._bn1(self._conv_head(outputs), training=training))
686
+ self.endpoints['head_1x1'] = outputs
687
+
688
+ if self._global_params.local_pooling:
689
+ shape = outputs.get_shape().as_list()
690
+ kernel_size = [
691
+ 1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1]
692
+ outputs = tf.nn.avg_pool(
693
+ outputs, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID')
694
+ self.endpoints['pooled_features'] = outputs
695
+ if not pooled_features_only:
696
+ if self._dropout:
697
+ outputs = self._dropout(outputs, training=training)
698
+ self.endpoints['global_pool'] = outputs
699
+ if self._fc:
700
+ outputs = tf.squeeze(outputs, self._spatial_dims)
701
+ outputs = self._fc(outputs)
702
+ self.endpoints['head'] = outputs
703
+ else:
704
+ outputs = self._avg_pooling(outputs)
705
+ self.endpoints['pooled_features'] = outputs
706
+ if not pooled_features_only:
707
+ if self._dropout:
708
+ outputs = self._dropout(outputs, training=training)
709
+ self.endpoints['global_pool'] = outputs
710
+ if self._fc:
711
+ outputs = self._fc(outputs)
712
+ self.endpoints['head'] = outputs
713
+ return outputs
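For orientation, a minimal usage sketch of the TF1 builder API that wraps this model class, assuming external_data/original_tf is on the Python path; it relies only on the efficientnet_builder.build_model(features, model_name, training) call also used by eval_ckpt_main.py below, and the second return value is assumed to be the endpoints dict populated by call(). Weights are randomly initialized here, so the sketch only illustrates shapes and endpoint names.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import efficientnet_builder  # assumed importable from this directory

# Build EfficientNet-B0 in a TF1 graph and inspect logits and endpoints.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits, endpoints = efficientnet_builder.build_model(images, 'efficientnet-b0', False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits, {images: np.zeros((1, 224, 224, 3), np.float32)})
    print(out.shape)          # (1, 1000)
    print(sorted(endpoints))  # includes 'features' and 'reduction_1'..'reduction_5'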
external_data/original_tf/eval_ckpt_main.py ADDED
@@ -0,0 +1,221 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Eval checkpoint driver.
16
+
17
+ This is an example evaluation script for users to understand the EfficientNet
18
+ model checkpoints on CPU. To serve EfficientNet, please consider exporting a
19
+ `SavedModel` from checkpoints and using tf-serving to serve it.
20
+ """
21
+
22
+ from __future__ import absolute_import
23
+ from __future__ import division
24
+ from __future__ import print_function
25
+
26
+ import json
27
+ import sys
28
+ from absl import app
29
+ from absl import flags
30
+ import numpy as np
31
+ import tensorflow as tf
32
+
33
+
34
+ import efficientnet_builder
35
+ import preprocessing
36
+
37
+
38
+ flags.DEFINE_string('model_name', 'efficientnet-b0', 'Model name to eval.')
39
+ flags.DEFINE_string('runmode', 'examples', 'Running mode: examples or imagenet')
40
+ flags.DEFINE_string('imagenet_eval_glob', None,
41
+ 'Imagenet eval image glob, '
42
+ 'such as /imagenet/ILSVRC2012*.JPEG')
43
+ flags.DEFINE_string('imagenet_eval_label', None,
44
+ 'Imagenet eval label file path, '
45
+ 'such as /imagenet/ILSVRC2012_validation_ground_truth.txt')
46
+ flags.DEFINE_string('ckpt_dir', '/tmp/ckpt/', 'Checkpoint folders')
47
+ flags.DEFINE_string('example_img', '/tmp/panda.jpg',
48
+ 'Filepath for a single example image.')
49
+ flags.DEFINE_string('labels_map_file', '/tmp/labels_map.txt',
50
+ 'Labels map from label id to its meaning.')
51
+ flags.DEFINE_integer('num_images', 5000,
52
+ 'Number of images to eval. Use -1 to eval all images.')
53
+ FLAGS = flags.FLAGS
54
+
55
+ MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
56
+ STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
57
+
58
+
59
+ class EvalCkptDriver(object):
60
+ """A driver for running eval inference.
61
+
62
+ Attributes:
63
+ model_name: str. Model name to eval.
64
+ batch_size: int. Eval batch size.
65
+ num_classes: int. Number of classes, default to 1000 for ImageNet.
66
+ image_size: int. Input image size, determined by model name.
67
+ """
68
+
69
+ def __init__(self, model_name='efficientnet-b0', batch_size=1):
70
+ """Initialize internal variables."""
71
+ self.model_name = model_name
72
+ self.batch_size = batch_size
73
+ self.num_classes = 1000
74
+ # Model Scaling parameters
75
+ _, _, self.image_size, _ = efficientnet_builder.efficientnet_params(
76
+ model_name)
77
+
78
+ def restore_model(self, sess, ckpt_dir):
79
+ """Restore variables from checkpoint dir."""
80
+ checkpoint = tf.train.latest_checkpoint(ckpt_dir)
81
+ ema = tf.train.ExponentialMovingAverage(decay=0.9999)
82
+ ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
83
+ for v in tf.global_variables():
84
+ if 'moving_mean' in v.name or 'moving_variance' in v.name:
85
+ ema_vars.append(v)
86
+ ema_vars = list(set(ema_vars))
87
+ var_dict = ema.variables_to_restore(ema_vars)
88
+ saver = tf.train.Saver(var_dict, max_to_keep=1)
89
+ saver.restore(sess, checkpoint)
90
+
91
+ def build_model(self, features, is_training):
92
+ """Build model with input features."""
93
+ features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
94
+ features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)
95
+ logits, _ = efficientnet_builder.build_model(
96
+ features, self.model_name, is_training)
97
+ probs = tf.nn.softmax(logits)
98
+ probs = tf.squeeze(probs)
99
+ return probs
100
+
101
+ def build_dataset(self, filenames, labels, is_training):
102
+ """Build input dataset."""
103
+ filenames = tf.constant(filenames)
104
+ labels = tf.constant(labels)
105
+ dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
106
+
107
+ def _parse_function(filename, label):
108
+ image_string = tf.read_file(filename)
109
+ image_decoded = preprocessing.preprocess_image(
110
+ image_string, is_training, image_size=self.image_size)
111
+ image = tf.cast(image_decoded, tf.float32)
112
+ return image, label
113
+
114
+ dataset = dataset.map(_parse_function)
115
+ dataset = dataset.batch(self.batch_size)
116
+
117
+ iterator = dataset.make_one_shot_iterator()
118
+ images, labels = iterator.get_next()
119
+ return images, labels
120
+
121
+ def run_inference(self, ckpt_dir, image_files, labels):
122
+ """Build and run inference on the target images and labels."""
123
+ with tf.Graph().as_default(), tf.Session() as sess:
124
+ images, labels = self.build_dataset(image_files, labels, False)
125
+ probs = self.build_model(images, is_training=False)
126
+
127
+ sess.run(tf.global_variables_initializer())
128
+ self.restore_model(sess, ckpt_dir)
129
+
130
+ prediction_idx = []
131
+ prediction_prob = []
132
+ for _ in range(len(image_files) // self.batch_size):
133
+ out_probs = sess.run(probs)
134
+ idx = np.argsort(out_probs)[::-1]
135
+ prediction_idx.append(idx[:5])
136
+ prediction_prob.append([out_probs[pid] for pid in idx[:5]])
137
+
138
+ # Return the top 5 predictions (idx and prob) for each image.
139
+ return prediction_idx, prediction_prob
140
+
141
+
142
+ def eval_example_images(model_name, ckpt_dir, image_files, labels_map_file):
143
+ """Eval a list of example images.
144
+
145
+ Args:
146
+ model_name: str. The name of model to eval.
147
+ ckpt_dir: str. Checkpoint directory path.
148
+ image_files: List[str]. A list of image file paths.
149
+ labels_map_file: str. The labels map file path.
150
+
151
+ Returns:
152
+ A tuple (pred_idx, pred_prob), where pred_idx is the top 5 prediction
153
+ index and pred_prob is the top 5 prediction probability.
154
+ """
155
+ eval_ckpt_driver = EvalCkptDriver(model_name)
156
+ classes = json.loads(tf.gfile.Open(labels_map_file).read())
157
+ pred_idx, pred_prob = eval_ckpt_driver.run_inference(
158
+ ckpt_dir, image_files, [0] * len(image_files))
159
+ for i in range(len(image_files)):
160
+ print('predicted class for image {}: '.format(image_files[i]))
161
+ for j, idx in enumerate(pred_idx[i]):
162
+ print(' -> top_{} ({:4.2f}%): {} '.format(
163
+ j, pred_prob[i][j] * 100, classes[str(idx)]))
164
+ return pred_idx, pred_prob
165
+
166
+
167
+ def eval_imagenet(model_name,
168
+ ckpt_dir,
169
+ imagenet_eval_glob,
170
+ imagenet_eval_label,
171
+ num_images):
172
+ """Eval ImageNet images and report top1/top5 accuracy.
173
+
174
+ Args:
175
+ model_name: str. The name of model to eval.
176
+ ckpt_dir: str. Checkpoint directory path.
177
+ imagenet_eval_glob: str. File path glob for all eval images.
178
+ imagenet_eval_label: str. File path for eval label.
179
+ num_images: int. Number of images to eval: -1 means eval the whole dataset.
180
+
181
+ Returns:
182
+ A tuple (top1, top5) for top1 and top5 accuracy.
183
+ """
184
+ eval_ckpt_driver = EvalCkptDriver(model_name)
185
+ imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
186
+ imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
187
+ if num_images < 0:
188
+ num_images = len(imagenet_filenames)
189
+ image_files = imagenet_filenames[:num_images]
190
+ labels = imagenet_val_labels[:num_images]
191
+
192
+ pred_idx, _ = eval_ckpt_driver.run_inference(ckpt_dir, image_files, labels)
193
+ top1_cnt, top5_cnt = 0.0, 0.0
194
+ for i, label in enumerate(labels):
195
+ top1_cnt += label in pred_idx[i][:1]
196
+ top5_cnt += label in pred_idx[i][:5]
197
+ if i % 100 == 0:
198
+ print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(
199
+ i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
200
+ sys.stdout.flush()
201
+ top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
202
+ print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5))
203
+ return top1, top5
204
+
205
+
206
+ def main(unused_argv):
207
+ tf.logging.set_verbosity(tf.logging.ERROR)
208
+ if FLAGS.runmode == 'examples':
209
+ # Run inference for an example image.
210
+ eval_example_images(FLAGS.model_name, FLAGS.ckpt_dir, [FLAGS.example_img],
211
+ FLAGS.labels_map_file)
212
+ elif FLAGS.runmode == 'imagenet':
213
+ # Run inference for imagenet.
214
+ eval_imagenet(FLAGS.model_name, FLAGS.ckpt_dir, FLAGS.imagenet_eval_glob,
215
+ FLAGS.imagenet_eval_label, FLAGS.num_images)
216
+ else:
217
+ print('must specify runmode: examples or imagenet')
218
+
219
+
220
+ if __name__ == '__main__':
221
+ app.run(main)
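The driver above can also be used directly from Python rather than through app.run; a minimal sketch, assuming the checkpoint, example image and labels map sit at the flag-default paths (placeholders, substitute real files):

from eval_ckpt_main import eval_example_images

# Scores one image with a downloaded EfficientNet-B0 checkpoint.
pred_idx, pred_prob = eval_example_images(
    model_name='efficientnet-b0',
    ckpt_dir='/tmp/ckpt/',
    image_files=['/tmp/panda.jpg'],
    labels_map_file='/tmp/labels_map.txt')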
external_data/original_tf/preprocessing.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ImageNet preprocessing."""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ from absl import logging
21
+
22
+ import tensorflow.compat.v1 as tf
23
+
24
+
25
+ IMAGE_SIZE = 224
26
+ CROP_PADDING = 32
27
+
28
+
29
+ def distorted_bounding_box_crop(image_bytes,
30
+ bbox,
31
+ min_object_covered=0.1,
32
+ aspect_ratio_range=(0.75, 1.33),
33
+ area_range=(0.05, 1.0),
34
+ max_attempts=100,
35
+ scope=None):
36
+ """Generates cropped_image using one of the bboxes randomly distorted.
37
+
38
+ See `tf.image.sample_distorted_bounding_box` for more documentation.
39
+
40
+ Args:
41
+ image_bytes: `Tensor` of binary image data.
42
+ bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
43
+ where each coordinate is [0, 1) and the coordinates are arranged
44
+ as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
45
+ image.
46
+ min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
47
+ area of the image must contain at least this fraction of any bounding
48
+ box supplied.
49
+ aspect_ratio_range: An optional list of `float`s. The cropped area of the
50
+ image must have an aspect ratio = width / height within this range.
51
+ area_range: An optional list of `float`s. The cropped area of the image
52
+ must contain a fraction of the supplied image within this range.
53
+ max_attempts: An optional `int`. Number of attempts at generating a cropped
54
+ region of the image of the specified constraints. After `max_attempts`
55
+ failures, return the entire image.
56
+ scope: Optional `str` for name scope.
57
+ Returns:
58
+ cropped image `Tensor`
59
+ """
60
+ with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
61
+ shape = tf.image.extract_jpeg_shape(image_bytes)
62
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
63
+ shape,
64
+ bounding_boxes=bbox,
65
+ min_object_covered=min_object_covered,
66
+ aspect_ratio_range=aspect_ratio_range,
67
+ area_range=area_range,
68
+ max_attempts=max_attempts,
69
+ use_image_if_no_bounding_boxes=True)
70
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
71
+
72
+ # Crop the image to the specified bounding box.
73
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
74
+ target_height, target_width, _ = tf.unstack(bbox_size)
75
+ crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
76
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
77
+
78
+ return image
79
+
80
+
81
+ def _at_least_x_are_equal(a, b, x):
82
+ """At least `x` of `a` and `b` `Tensors` are equal."""
83
+ match = tf.equal(a, b)
84
+ match = tf.cast(match, tf.int32)
85
+ return tf.greater_equal(tf.reduce_sum(match), x)
86
+
87
+
88
+ def _decode_and_random_crop(image_bytes, image_size):
89
+ """Make a random crop of image_size."""
90
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
91
+ image = distorted_bounding_box_crop(
92
+ image_bytes,
93
+ bbox,
94
+ min_object_covered=0.1,
95
+ aspect_ratio_range=(3. / 4, 4. / 3.),
96
+ area_range=(0.08, 1.0),
97
+ max_attempts=10,
98
+ scope=None)
99
+ original_shape = tf.image.extract_jpeg_shape(image_bytes)
100
+ bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
101
+
102
+ image = tf.cond(
103
+ bad,
104
+ lambda: _decode_and_center_crop(image_bytes, image_size),
105
+ lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda
106
+ [image_size, image_size])[0])
107
+
108
+ return image
109
+
110
+
111
+ def _decode_and_center_crop(image_bytes, image_size):
112
+ """Crops to center of image with padding then scales image_size."""
113
+ shape = tf.image.extract_jpeg_shape(image_bytes)
114
+ image_height = shape[0]
115
+ image_width = shape[1]
116
+
117
+ padded_center_crop_size = tf.cast(
118
+ ((image_size / (image_size + CROP_PADDING)) *
119
+ tf.cast(tf.minimum(image_height, image_width), tf.float32)),
120
+ tf.int32)
121
+
122
+ offset_height = ((image_height - padded_center_crop_size) + 1) // 2
123
+ offset_width = ((image_width - padded_center_crop_size) + 1) // 2
124
+ crop_window = tf.stack([offset_height, offset_width,
125
+ padded_center_crop_size, padded_center_crop_size])
126
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
127
+ image = tf.image.resize_bicubic([image], [image_size, image_size])[0]
128
+ return image
129
+
130
+
131
+ def _flip(image):
132
+ """Random horizontal image flip."""
133
+ image = tf.image.random_flip_left_right(image)
134
+ return image
135
+
136
+
137
+ def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE,
138
+ augment_name=None,
139
+ randaug_num_layers=None, randaug_magnitude=None):
140
+ """Preprocesses the given image for evaluation.
141
+
142
+ Args:
143
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
144
+ use_bfloat16: `bool` for whether to use bfloat16.
145
+ image_size: image size.
146
+ augment_name: `string` that is the name of the augmentation method
147
+ to apply to the image. `autoaugment` if AutoAugment is to be used or
148
+ `randaugment` if RandAugment is to be used. If the value is `None` no
149
+ augmentation method will be applied. See autoaugment.py for more
150
+ details.
151
+ randaug_num_layers: 'int', if RandAug is used, what should the number of
152
+ layers be. See autoaugment.py for detailed description.
153
+ randaug_magnitude: 'int', if RandAug is used, what should the magnitude
154
+ be. See autoaugment.py for detailed description.
155
+
156
+ Returns:
157
+ A preprocessed image `Tensor`.
158
+ """
159
+ image = _decode_and_random_crop(image_bytes, image_size)
160
+ image = _flip(image)
161
+ image = tf.reshape(image, [image_size, image_size, 3])
162
+
163
+ image = tf.image.convert_image_dtype(
164
+ image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
165
+
166
+ if augment_name:
167
+ try:
168
+ import autoaugment # pylint: disable=g-import-not-at-top
169
+ except ImportError as e:
170
+ logging.exception('Autoaugment is not supported in TF 2.x.')
171
+ raise e
172
+
173
+ logging.info('Apply AutoAugment policy %s', augment_name)
174
+ input_image_type = image.dtype
175
+ image = tf.clip_by_value(image, 0.0, 255.0)
176
+ image = tf.cast(image, dtype=tf.uint8)
177
+
178
+ if augment_name == 'autoaugment':
179
+ logging.info('Apply AutoAugment policy %s', augment_name)
180
+ image = autoaugment.distort_image_with_autoaugment(image, 'v0')
181
+ elif augment_name == 'randaugment':
182
+ image = autoaugment.distort_image_with_randaugment(
183
+ image, randaug_num_layers, randaug_magnitude)
184
+ else:
185
+ raise ValueError('Invalid value for augment_name: %s' % (augment_name))
186
+
187
+ image = tf.cast(image, dtype=input_image_type)
188
+ return image
189
+
190
+
191
+ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
192
+ """Preprocesses the given image for evaluation.
193
+
194
+ Args:
195
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
196
+ use_bfloat16: `bool` for whether to use bfloat16.
197
+ image_size: image size.
198
+
199
+ Returns:
200
+ A preprocessed image `Tensor`.
201
+ """
202
+ image = _decode_and_center_crop(image_bytes, image_size)
203
+ image = tf.reshape(image, [image_size, image_size, 3])
204
+ image = tf.image.convert_image_dtype(
205
+ image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
206
+ return image
207
+
208
+
209
+ def preprocess_image(image_bytes,
210
+ is_training=False,
211
+ use_bfloat16=False,
212
+ image_size=IMAGE_SIZE,
213
+ augment_name=None,
214
+ randaug_num_layers=None,
215
+ randaug_magnitude=None):
216
+ """Preprocesses the given image.
217
+
218
+ Args:
219
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
220
+ is_training: `bool` for whether the preprocessing is for training.
221
+ use_bfloat16: `bool` for whether to use bfloat16.
222
+ image_size: image size.
223
+ augment_name: `string` that is the name of the augmentation method
224
+ to apply to the image. `autoaugment` if AutoAugment is to be used or
225
+ `randaugment` if RandAugment is to be used. If the value is `None` no
226
+ augmentation method will be applied. See autoaugment.py for more
227
+ details.
228
+ randaug_num_layers: 'int', if RandAug is used, what should the number of
229
+ layers be. See autoaugment.py for detailed description.
230
+ randaug_magnitude: 'int', if RandAug is used, what should the magnitude
231
+ be. See autoaugment.py for detailed description.
232
+
233
+ Returns:
234
+ A preprocessed image `Tensor` with value range of [0, 255].
235
+ """
236
+ if is_training:
237
+ return preprocess_for_train(
238
+ image_bytes, use_bfloat16, image_size, augment_name,
239
+ randaug_num_layers, randaug_magnitude)
240
+ else:
241
+ return preprocess_for_eval(image_bytes, use_bfloat16, image_size)
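For context, a sketch of how preprocess_image plugs into a tf.data input pipeline, mirroring the _parse_function pattern used by the eval drivers in this directory. It assumes TF1-style graph execution and JPEG inputs (the crop helpers call tf.image.extract_jpeg_shape / decode_and_crop_jpeg); the file path is a placeholder.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import preprocessing  # this module

filenames = ['/tmp/panda.jpg']  # placeholder JPEG path

def _parse(filename):
    image_bytes = tf.read_file(filename)
    # Eval path: center crop with CROP_PADDING, then bicubic resize to 224x224.
    return preprocessing.preprocess_image(image_bytes, is_training=False, image_size=224)

dataset = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
dataset = dataset.map(_parse).batch(1)
images = dataset.make_one_shot_iterator().get_next()  # float32, shape (1, 224, 224, 3)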
external_data/original_tf/utils.py ADDED
@@ -0,0 +1,405 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Model utilities."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import json
22
+ import os
23
+ import sys
24
+
25
+ from absl import logging
26
+ import numpy as np
27
+ import tensorflow.compat.v1 as tf
28
+
29
+ from tensorflow.python.tpu import tpu_function # pylint:disable=g-direct-tensorflow-import
30
+
31
+
32
+ def build_learning_rate(initial_lr,
33
+ global_step,
34
+ steps_per_epoch=None,
35
+ lr_decay_type='exponential',
36
+ decay_factor=0.97,
37
+ decay_epochs=2.4,
38
+ total_steps=None,
39
+ warmup_epochs=5):
40
+ """Build learning rate."""
41
+ if lr_decay_type == 'exponential':
42
+ assert steps_per_epoch is not None
43
+ decay_steps = steps_per_epoch * decay_epochs
44
+ lr = tf.train.exponential_decay(
45
+ initial_lr, global_step, decay_steps, decay_factor, staircase=True)
46
+ elif lr_decay_type == 'cosine':
47
+ assert total_steps is not None
48
+ lr = 0.5 * initial_lr * (
49
+ 1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
50
+ elif lr_decay_type == 'constant':
51
+ lr = initial_lr
52
+ else:
53
+ assert False, 'Unknown lr_decay_type : %s' % lr_decay_type
54
+
55
+ if warmup_epochs:
56
+ logging.info('Learning rate warmup_epochs: %d', warmup_epochs)
57
+ warmup_steps = int(warmup_epochs * steps_per_epoch)
58
+ warmup_lr = (
59
+ initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
60
+ warmup_steps, tf.float32))
61
+ lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
62
+
63
+ return lr
64
+
65
+
66
+ def build_optimizer(learning_rate,
67
+ optimizer_name='rmsprop',
68
+ decay=0.9,
69
+ epsilon=0.001,
70
+ momentum=0.9):
71
+ """Build optimizer."""
72
+ if optimizer_name == 'sgd':
73
+ logging.info('Using SGD optimizer')
74
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
75
+ elif optimizer_name == 'momentum':
76
+ logging.info('Using Momentum optimizer')
77
+ optimizer = tf.train.MomentumOptimizer(
78
+ learning_rate=learning_rate, momentum=momentum)
79
+ elif optimizer_name == 'rmsprop':
80
+ logging.info('Using RMSProp optimizer')
81
+ optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
82
+ epsilon)
83
+ else:
84
+ logging.fatal('Unknown optimizer: %s', optimizer_name)
85
+
86
+ return optimizer
87
+
88
+
89
+ class TpuBatchNormalization(tf.layers.BatchNormalization):
90
+ # class TpuBatchNormalization(tf.layers.BatchNormalization):
91
+ """Cross replica batch normalization."""
92
+
93
+ def __init__(self, fused=False, **kwargs):
94
+ if fused in (True, None):
95
+ raise ValueError('TpuBatchNormalization does not support fused=True.')
96
+ super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
97
+
98
+ def _cross_replica_average(self, t, num_shards_per_group):
99
+ """Calculates the average value of input tensor across TPU replicas."""
100
+ num_shards = tpu_function.get_tpu_context().number_of_shards
101
+ group_assignment = None
102
+ if num_shards_per_group > 1:
103
+ if num_shards % num_shards_per_group != 0:
104
+ raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
105
+ % (num_shards, num_shards_per_group))
106
+ num_groups = num_shards // num_shards_per_group
107
+ group_assignment = [[
108
+ x for x in range(num_shards) if x // num_shards_per_group == y
109
+ ] for y in range(num_groups)]
110
+ return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
111
+ num_shards_per_group, t.dtype)
112
+
113
+ def _moments(self, inputs, reduction_axes, keep_dims):
114
+ """Compute the mean and variance: it overrides the original _moments."""
115
+ shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
116
+ inputs, reduction_axes, keep_dims=keep_dims)
117
+
118
+ num_shards = tpu_function.get_tpu_context().number_of_shards or 1
119
+ if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
120
+ num_shards_per_group = 1
121
+ else:
122
+ num_shards_per_group = max(8, num_shards // 8)
123
+ logging.info('TpuBatchNormalization with num_shards_per_group %s',
124
+ num_shards_per_group)
125
+ if num_shards_per_group > 1:
126
+ # Compute variance using: Var[X]= E[X^2] - E[X]^2.
127
+ shard_square_of_mean = tf.math.square(shard_mean)
128
+ shard_mean_of_square = shard_variance + shard_square_of_mean
129
+ group_mean = self._cross_replica_average(
130
+ shard_mean, num_shards_per_group)
131
+ group_mean_of_square = self._cross_replica_average(
132
+ shard_mean_of_square, num_shards_per_group)
133
+ group_variance = group_mean_of_square - tf.math.square(group_mean)
134
+ return (group_mean, group_variance)
135
+ else:
136
+ return (shard_mean, shard_variance)
137
+
138
+
139
+ class BatchNormalization(tf.layers.BatchNormalization):
140
+ """Fixed default name of BatchNormalization to match TpuBatchNormalization."""
141
+
142
+ def __init__(self, name='tpu_batch_normalization', **kwargs):
143
+ super(BatchNormalization, self).__init__(name=name, **kwargs)
144
+
145
+
146
+ def drop_connect(inputs, is_training, survival_prob):
147
+ """Drop the entire conv with given survival probability."""
148
+ # "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
149
+ if not is_training:
150
+ return inputs
151
+
152
+ # Compute tensor.
153
+ batch_size = tf.shape(inputs)[0]
154
+ random_tensor = survival_prob
155
+ random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype)
156
+ binary_tensor = tf.floor(random_tensor)
157
+ # Unlike the conventional approach of multiplying by survival_prob at test time,
158
+ # here we divide by survival_prob at training time, so that no additional compute
159
+ # is needed at test time.
160
+ output = tf.div(inputs, survival_prob) * binary_tensor
161
+ return output
162
+
163
+
164
+ def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
165
+ """Archive a checkpoint if the metric is better."""
166
+ ckpt_dir, ckpt_name = os.path.split(ckpt_path)
167
+
168
+ saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
169
+ saved_objective = float('-inf')
170
+ if tf.gfile.Exists(saved_objective_path):
171
+ with tf.gfile.GFile(saved_objective_path, 'r') as f:
172
+ saved_objective = float(f.read())
173
+ if saved_objective > ckpt_objective:
174
+ logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
175
+ return False
176
+
177
+ filenames = tf.gfile.Glob(ckpt_path + '.*')
178
+ if filenames is None:
179
+ logging.info('No files to copy for checkpoint %s', ckpt_path)
180
+ return False
181
+
182
+ # Clear the old folder.
183
+ dst_dir = os.path.join(ckpt_dir, 'archive')
184
+ if tf.gfile.Exists(dst_dir):
185
+ tf.gfile.DeleteRecursively(dst_dir)
186
+ tf.gfile.MakeDirs(dst_dir)
187
+
188
+ # Write checkpoints.
189
+ for f in filenames:
190
+ dest = os.path.join(dst_dir, os.path.basename(f))
191
+ tf.gfile.Copy(f, dest, overwrite=True)
192
+ ckpt_state = tf.train.generate_checkpoint_state_proto(
193
+ dst_dir,
194
+ model_checkpoint_path=ckpt_name,
195
+ all_model_checkpoint_paths=[ckpt_name])
196
+ with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f:
197
+ f.write(str(ckpt_state))
198
+ with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f:
199
+ f.write('%s' % ckpt_eval)
200
+
201
+ # Update the best objective.
202
+ with tf.gfile.GFile(saved_objective_path, 'w') as f:
203
+ f.write('%f' % ckpt_objective)
204
+
205
+ logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir)
206
+ return True
207
+
208
+
209
+ def get_ema_vars():
210
+ """Get all exponential moving average (ema) variables."""
211
+ ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
212
+ for v in tf.global_variables():
213
+ # We maintain mva for batch norm moving mean and variance as well.
214
+ if 'moving_mean' in v.name or 'moving_variance' in v.name:
215
+ ema_vars.append(v)
216
+ return list(set(ema_vars))
217
+
218
+
219
+ class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer):
220
+ """Wrap keras DepthwiseConv2D to tf.layers."""
221
+
222
+ pass
223
+
224
+
225
+ class EvalCkptDriver(object):
226
+ """A driver for running eval inference.
227
+
228
+ Attributes:
229
+ model_name: str. Model name to eval.
230
+ batch_size: int. Eval batch size.
231
+ image_size: int. Input image size, determined by model name.
232
+ num_classes: int. Number of classes, default to 1000 for ImageNet.
233
+ include_background_label: whether to include extra background label.
234
+ """
235
+
236
+ def __init__(self,
237
+ model_name,
238
+ batch_size=1,
239
+ image_size=224,
240
+ num_classes=1000,
241
+ include_background_label=False):
242
+ """Initialize internal variables."""
243
+ self.model_name = model_name
244
+ self.batch_size = batch_size
245
+ self.num_classes = num_classes
246
+ self.include_background_label = include_background_label
247
+ self.image_size = image_size
248
+
249
+ def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None):
250
+ """Restore variables from checkpoint dir."""
251
+ sess.run(tf.global_variables_initializer())
252
+ checkpoint = tf.train.latest_checkpoint(ckpt_dir)
253
+ if enable_ema:
254
+ ema = tf.train.ExponentialMovingAverage(decay=0.0)
255
+ ema_vars = get_ema_vars()
256
+ var_dict = ema.variables_to_restore(ema_vars)
257
+ ema_assign_op = ema.apply(ema_vars)
258
+ else:
259
+ var_dict = get_ema_vars()
260
+ ema_assign_op = None
261
+
262
+ tf.train.get_or_create_global_step()
263
+ sess.run(tf.global_variables_initializer())
264
+ saver = tf.train.Saver(var_dict, max_to_keep=1)
265
+ saver.restore(sess, checkpoint)
266
+
267
+ if export_ckpt:
268
+ if ema_assign_op is not None:
269
+ sess.run(ema_assign_op)
270
+ saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
271
+ saver.save(sess, export_ckpt)
272
+
273
+ def build_model(self, features, is_training):
274
+ """Build model with input features."""
275
+ del features, is_training
276
+ raise ValueError('Must be implemented by subclasses.')
277
+
278
+ def get_preprocess_fn(self):
279
+ raise ValueError('Must be implemented by subclasses.')
280
+
281
+ def build_dataset(self, filenames, labels, is_training):
282
+ """Build input dataset."""
283
+ batch_drop_remainder = False
284
+ if 'condconv' in self.model_name and not is_training:
285
+ # CondConv layers can only be called with known batch dimension. Thus, we
286
+ # must drop all remaining examples that do not make up one full batch.
287
+ # To ensure all examples are evaluated, use a batch size that evenly
288
+ # divides the number of files.
289
+ batch_drop_remainder = True
290
+ num_files = len(filenames)
291
+ if num_files % self.batch_size != 0:
292
+ tf.logging.warn('Remaining examples in last batch are not being '
293
+ 'evaluated.')
294
+ filenames = tf.constant(filenames)
295
+ labels = tf.constant(labels)
296
+ dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
297
+
298
+ def _parse_function(filename, label):
299
+ image_string = tf.read_file(filename)
300
+ preprocess_fn = self.get_preprocess_fn()
301
+ image_decoded = preprocess_fn(
302
+ image_string, is_training, image_size=self.image_size)
303
+ image = tf.cast(image_decoded, tf.float32)
304
+ return image, label
305
+
306
+ dataset = dataset.map(_parse_function)
307
+ dataset = dataset.batch(self.batch_size,
308
+ drop_remainder=batch_drop_remainder)
309
+
310
+ iterator = dataset.make_one_shot_iterator()
311
+ images, labels = iterator.get_next()
312
+ return images, labels
313
+
314
+ def run_inference(self,
315
+ ckpt_dir,
316
+ image_files,
317
+ labels,
318
+ enable_ema=True,
319
+ export_ckpt=None):
320
+ """Build and run inference on the target images and labels."""
321
+ label_offset = 1 if self.include_background_label else 0
322
+ with tf.Graph().as_default(), tf.Session() as sess:
323
+ images, labels = self.build_dataset(image_files, labels, False)
324
+ probs = self.build_model(images, is_training=False)
325
+ if isinstance(probs, tuple):
326
+ probs = probs[0]
327
+
328
+ self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt)
329
+
330
+ prediction_idx = []
331
+ prediction_prob = []
332
+ for _ in range(len(image_files) // self.batch_size):
333
+ out_probs = sess.run(probs)
334
+ idx = np.argsort(out_probs)[::-1]
335
+ prediction_idx.append(idx[:5] - label_offset)
336
+ prediction_prob.append([out_probs[pid] for pid in idx[:5]])
337
+
338
+ # Return the top 5 predictions (idx and prob) for each image.
339
+ return prediction_idx, prediction_prob
340
+
341
+ def eval_example_images(self,
342
+ ckpt_dir,
343
+ image_files,
344
+ labels_map_file,
345
+ enable_ema=True,
346
+ export_ckpt=None):
347
+ """Eval a list of example images.
348
+
349
+ Args:
350
+ ckpt_dir: str. Checkpoint directory path.
351
+ image_files: List[str]. A list of image file paths.
352
+ labels_map_file: str. The labels map file path.
353
+ enable_ema: enable exponential moving average.
354
+ export_ckpt: export ckpt folder.
355
+
356
+ Returns:
357
+ A tuple (pred_idx, pred_prob), where pred_idx is the top 5 prediction
358
+ index and pred_prob is the top 5 prediction probability.
359
+ """
360
+ classes = json.loads(tf.gfile.Open(labels_map_file).read())
361
+ pred_idx, pred_prob = self.run_inference(
362
+ ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt)
363
+ for i in range(len(image_files)):
364
+ print('predicted class for image {}: '.format(image_files[i]))
365
+ for j, idx in enumerate(pred_idx[i]):
366
+ print(' -> top_{} ({:4.2f}%): {} '.format(j, pred_prob[i][j] * 100,
367
+ classes[str(idx)]))
368
+ return pred_idx, pred_prob
369
+
370
+ def eval_imagenet(self, ckpt_dir, imagenet_eval_glob,
371
+ imagenet_eval_label, num_images, enable_ema, export_ckpt):
372
+ """Eval ImageNet images and report top1/top5 accuracy.
373
+
374
+ Args:
375
+ ckpt_dir: str. Checkpoint directory path.
376
+ imagenet_eval_glob: str. File path glob for all eval images.
377
+ imagenet_eval_label: str. File path for eval label.
378
+ num_images: int. Number of images to eval: -1 means eval the whole
379
+ dataset.
380
+ enable_ema: enable exponential moving average.
381
+ export_ckpt: export checkpoint folder.
382
+
383
+ Returns:
384
+ A tuple (top1, top5) for top1 and top5 accuracy.
385
+ """
386
+ imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
387
+ imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
388
+ if num_images < 0:
389
+ num_images = len(imagenet_filenames)
390
+ image_files = imagenet_filenames[:num_images]
391
+ labels = imagenet_val_labels[:num_images]
392
+
393
+ pred_idx, _ = self.run_inference(
394
+ ckpt_dir, image_files, labels, enable_ema, export_ckpt)
395
+ top1_cnt, top5_cnt = 0.0, 0.0
396
+ for i, label in enumerate(labels):
397
+ top1_cnt += label in pred_idx[i][:1]
398
+ top5_cnt += label in pred_idx[i][:5]
399
+ if i % 100 == 0:
400
+ print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(
401
+ i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
402
+ sys.stdout.flush()
403
+ top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
404
+ print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5))
405
+ return top1, top5
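The drop_connect comment above describes stochastic depth: an example is kept with probability survival_prob, and kept outputs are divided by survival_prob so the expected activation is unchanged. A small NumPy illustration of that arithmetic (illustration only, not the TF implementation):

import numpy as np

def drop_connect_np(inputs, survival_prob, rng):
    # Keep each example with probability survival_prob, then rescale the
    # kept ones by 1 / survival_prob, matching drop_connect above.
    batch_size = inputs.shape[0]
    random_tensor = survival_prob + rng.uniform(size=(batch_size, 1, 1, 1))
    binary_tensor = np.floor(random_tensor)  # 1.0 with probability survival_prob
    return inputs / survival_prob * binary_tensor

rng = np.random.default_rng(0)
x = np.ones((4, 2, 2, 3), dtype=np.float32)
y = drop_connect_np(x, survival_prob=0.8, rng=rng)
print(y[:, 0, 0, 0])  # each entry is either 0.0 or 1.25 (= 1 / 0.8)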
extract_tracks_from_videos.py ADDED
@@ -0,0 +1,105 @@
1
+ import argparse
2
+ import os
3
+ import yaml
4
+ import random
5
+ import pickle
6
+ import tqdm
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from generate_aligned_tracks import ALIGNED_TRACKS_FILE_NAME
12
+
13
+ SEED = 0xDEADFACE
14
+ TRACK_LENGTH = 50
15
+ DETECTOR_STEP = 6
16
+ BOX_MULT = 1.5
17
+
18
+ TRACKS_ROOT = 'tracks'
19
+ BOXES_FILE_NAME = 'boxes.float32'
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser(description='Extracts tracks from videos')
24
+ parser.add_argument('--num_parts', type=int, default=1, help='Number of parts')
25
+ parser.add_argument('--part', type=int, default=0, help='Part')
26
+
27
+ args = parser.parse_args()
28
+
29
+ with open('config.yaml', 'r') as f:
30
+ config = yaml.load(f)
31
+
32
+ with open(os.path.join(config['ARTIFACTS_PATH'], ALIGNED_TRACKS_FILE_NAME), 'rb') as f:
33
+ aligned_tracks = pickle.load(f)
34
+
35
+ part_size = len(aligned_tracks) // args.num_parts + 1
36
+ assert part_size * args.num_parts >= len(aligned_tracks)
37
+ part_start = part_size * args.part
38
+ part_end = min(part_start + part_size, len(aligned_tracks))
39
+ print('Part {} ({}, {})'.format(args.part, part_start, part_end))
40
+
41
+ random.seed(SEED)
42
+ for real_video, fake_video, aligned_track in tqdm.tqdm(aligned_tracks[part_start:part_end]):
43
+ if len(aligned_track) < TRACK_LENGTH // DETECTOR_STEP:
44
+ continue
45
+ real_boxes = [item[1] for item in aligned_track]
46
+ fake_boxes = [item[2] for item in aligned_track]
47
+ start_idx = random.randint(0, len(aligned_track) - TRACK_LENGTH // DETECTOR_STEP)
48
+ start_frame = aligned_track[start_idx][0] * DETECTOR_STEP
49
+ middle_idx = start_idx + TRACK_LENGTH // DETECTOR_STEP // 2
50
+
51
+ if random.choice([False, True]):
52
+ xmin, ymin, xmax, ymax = real_boxes[middle_idx]
53
+ else:
54
+ xmin, ymin, xmax, ymax = fake_boxes[middle_idx]
55
+
56
+ width = xmax - xmin
57
+ height = ymax - ymin
58
+ xcenter = xmin + width / 2
59
+ ycenter = ymin + height / 2
60
+ width = width * BOX_MULT
61
+ height = height * BOX_MULT
62
+ xmin = xcenter - width / 2
63
+ ymin = ycenter - height / 2
64
+ xmax = xmin + width
65
+ ymax = ymin + height
66
+
67
+ for video, boxes in [(real_video, real_boxes), (fake_video, fake_boxes)]:
68
+ capture = cv2.VideoCapture(os.path.join(config['DFDC_DATA_PATH'], video))
69
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
70
+ if frame_count == 0:
71
+ continue
72
+ frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
73
+ frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
74
+
75
+ xmin = max(int(xmin), 0)
76
+ xmax = min(int(xmax), frame_width)
77
+ ymin = max(int(ymin), 0)
78
+ ymax = min(int(ymax), frame_height)
79
+
80
+ dst_root = os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT,
81
+ video + '_{}_{}_{}'.format(start_frame, xmin, ymin))
82
+ if os.path.exists(dst_root):
83
+ continue
84
+ os.makedirs(dst_root)
85
+ for i in range(start_frame + TRACK_LENGTH):
86
+ capture.grab()
87
+ if i < start_frame:
88
+ continue
89
+ ret, frame = capture.retrieve()
90
+ if not ret:
91
+ continue
92
+ face = frame[ymin:ymax, xmin:xmax]
93
+ dst_path = os.path.join(dst_root, '{}.png'.format(i - start_frame))
94
+ cv2.imwrite(dst_path, face)
95
+
96
+ boxes = np.array(boxes, dtype=np.float32)
97
+ boxes[:, 0] -= xmin
98
+ boxes[:, 1] -= ymin
99
+ boxes[:, 2] -= xmin
100
+ boxes[:, 3] -= ymin
101
+ boxes.tofile(os.path.join(dst_root, BOXES_FILE_NAME))
102
+
103
+
104
+ if __name__ == '__main__':
105
+ main()
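The crop box above is the middle-frame detection scaled by BOX_MULT around its own center before being clipped to the frame. The same arithmetic as a standalone helper (for illustration only):

def expand_box(box, mult=1.5):
    # Scale an (xmin, ymin, xmax, ymax) box by `mult` around its center.
    xmin, ymin, xmax, ymax = box
    width, height = (xmax - xmin) * mult, (ymax - ymin) * mult
    xcenter, ycenter = (xmin + xmax) / 2, (ymin + ymax) / 2
    return (xcenter - width / 2, ycenter - height / 2,
            xcenter + width / 2, ycenter + height / 2)

print(expand_box((100, 100, 200, 200)))  # (75.0, 75.0, 225.0, 225.0)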
generate_aligned_tracks.py ADDED
@@ -0,0 +1,99 @@
1
+ import glob
2
+ import os
3
+ import yaml
4
+ import json
5
+ from collections import defaultdict
6
+ import tqdm
7
+ import pickle
8
+
9
+ from tracker.utils import iou
10
+
11
+ from generate_tracks import TRACKS_FILE_NAME
12
+
13
+ MIN_TRACK_LENGTH = 5
14
+ IOU_THRESHOLD = 0.5
15
+ METADATA_FILE_NAME = 'metadata.json'
16
+
17
+ ALIGNED_TRACKS_FILE_NAME = 'aligned_tracks.pkl'
18
+
19
+
20
+ def get_track(tracks, min_track_length):
21
+ good_tracks = [track for track in tracks if len(track) >= min_track_length]
22
+ if len(good_tracks) == 1:
23
+ return good_tracks[0]
24
+ else:
25
+ return None
26
+
27
+
28
+ def main():
29
+ with open('config.yaml', 'r') as f:
30
+ config = yaml.load(f)
31
+
32
+ video_to_meta = {}
33
+
34
+ for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], '**', METADATA_FILE_NAME), recursive=True):
35
+ root = os.path.basename(os.path.dirname(path))
36
+ with open(path, 'r') as f:
37
+ for video, meta in json.load(f).items():
38
+ video_to_meta[os.path.join(root, video)] = meta
39
+
40
+ real_video_to_fake_videos = defaultdict(list)
41
+ for video in video_to_meta:
42
+ root = os.path.dirname(video)
43
+ meta = video_to_meta[video]
44
+ if meta['label'] == 'FAKE':
45
+ original_video = os.path.join(root, meta['original'])
46
+ real_video_to_fake_videos[original_video].append(video)
47
+
48
+ print('Total number of real videos: {}'.format(len(real_video_to_fake_videos)))
49
+ print('Total number of fake videos: {}'.format(sum([len(fake_videos) for fake_videos in real_video_to_fake_videos.values()])))
50
+
51
+ with open(os.path.join(config['ARTIFACTS_PATH'], TRACKS_FILE_NAME), 'rb') as f:
52
+ video_to_tracks = pickle.load(f)
53
+
54
+ real_fake_aligned_tracks = []
55
+ real_videos = sorted(real_video_to_fake_videos)
56
+ for real_video in tqdm.tqdm(real_videos):
57
+ if real_video not in video_to_tracks:
58
+ continue
59
+ real_tracks = [track for track in video_to_tracks[real_video] if len(track) >= MIN_TRACK_LENGTH]
60
+
61
+ for fake_video in real_video_to_fake_videos[real_video]:
62
+ if fake_video not in video_to_tracks:
63
+ continue
64
+ fake_tracks = [track for track in video_to_tracks[fake_video] if len(track) >= MIN_TRACK_LENGTH]
65
+
66
+ for real_track in real_tracks:
67
+ real_frame_idx_to_bbox = {}
68
+ for real_frame_idx, real_bbox in real_track:
69
+ real_frame_idx_to_bbox[real_frame_idx] = real_bbox
70
+
71
+ for fake_track in fake_tracks:
72
+ fake_frame_idx_to_bbox = {}
73
+ ious = []
74
+ for fake_frame_idx, fake_bbox in fake_track:
75
+ fake_frame_idx_to_bbox[fake_frame_idx] = fake_bbox
76
+ if fake_frame_idx in real_frame_idx_to_bbox:
77
+ real_bbox = real_frame_idx_to_bbox[fake_frame_idx]
78
+ ious.append(iou(real_bbox, fake_bbox))
79
+ if len(ious) > 0 and min(ious) > IOU_THRESHOLD:
80
+ start_frame_idx = max(min(real_frame_idx_to_bbox), min(fake_frame_idx_to_bbox))
81
+ end_frame_idx = min(max(real_frame_idx_to_bbox), max(fake_frame_idx_to_bbox)) + 1
82
+ assert start_frame_idx < end_frame_idx
83
+ real_fake_aligned_track = []
84
+ for frame_idx in range(start_frame_idx, end_frame_idx):
85
+ real_bbox = real_frame_idx_to_bbox[frame_idx]
86
+ fake_bbox = fake_frame_idx_to_bbox[frame_idx]
87
+ assert iou(real_bbox, fake_bbox) > IOU_THRESHOLD
88
+ real_fake_aligned_track.append((frame_idx, real_bbox, fake_bbox))
89
+ real_fake_aligned_tracks.append((real_video, fake_video, real_fake_aligned_track))
90
+ break
91
+
92
+ print('Total number of tracks: {}'.format(len(real_fake_aligned_tracks)))
93
+
94
+ with open(os.path.join(config['ARTIFACTS_PATH'], ALIGNED_TRACKS_FILE_NAME), 'wb') as f:
95
+ pickle.dump(real_fake_aligned_tracks, f)
96
+
97
+
98
+ if __name__ == '__main__':
99
+ main()
generate_track_pairs.py ADDED
@@ -0,0 +1,70 @@
 
1
+ import yaml
2
+ import os
3
+ import json
4
+ from collections import defaultdict
5
+ import glob
6
+
7
+ from generate_aligned_tracks import METADATA_FILE_NAME
8
+ from extract_tracks_from_videos import TRACKS_ROOT
9
+
10
+ TRACK_PAIRS_FILE_NAME = 'track_pairs.txt'
11
+
12
+
13
+ def main():
14
+ with open('config.yaml', 'r') as f:
15
+ config = yaml.load(f)
16
+
17
+ video_to_tracks = defaultdict(list)
18
+
19
+ for path in glob.iglob(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, 'dfdc_train_part_*', '*.mp4_*')):
20
+ parts = path.split('/')
21
+ rel_path = '/'.join(parts[-2:])
22
+ video = '_'.join(rel_path.split('_')[:-3])
23
+ video_to_tracks[video].append(rel_path)
24
+
25
+ video_to_meta = {}
26
+
27
+ for path in glob.iglob(os.path.join(config['DFDC_DATA_PATH'], '**', METADATA_FILE_NAME), recursive=True):
28
+ root = os.path.basename(os.path.dirname(path))
29
+ with open(path, 'r') as f:
30
+ for video, meta in json.load(f).items():
31
+ video_to_meta[os.path.join(root, video)] = meta
32
+
33
+ fake_video_to_real_video = {}
34
+ for video in video_to_meta:
35
+ root = os.path.dirname(video)
36
+ meta = video_to_meta[video]
37
+ if meta['label'] == 'FAKE':
38
+ original_video = os.path.join(root, meta['original'])
39
+ fake_video_to_real_video[video] = original_video
40
+
41
+ print('Total number of fake videos: {}'.format(len(fake_video_to_real_video)))
42
+
43
+ track_pairs = []
44
+
45
+ fake_videos = sorted(fake_video_to_real_video)
46
+ for fake_video in fake_videos:
47
+ real_video = fake_video_to_real_video[fake_video]
48
+ fake_tracks = video_to_tracks[fake_video]
49
+ real_tracks = video_to_tracks[real_video]
50
+
51
+ for fake_track in fake_tracks:
52
+ if not os.path.exists(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, fake_track, '0.png')):
53
+ continue
54
+ suffix = fake_track[len(fake_video):]
55
+ for real_track in real_tracks:
56
+ if not os.path.exists(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT, real_track, '0.png')):
57
+ continue
58
+ if real_track.endswith(suffix):
59
+ track_pairs.append((real_track, fake_track))
60
+ break
61
+
62
+ print('Total number of track pairs: {}'.format(len(track_pairs)))
63
+
64
+ with open(os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME), 'w') as f:
65
+ for real_track, fake_track in track_pairs:
66
+ f.write('{},{}\n'.format(real_track, fake_track))
67
+
68
+
69
+ if __name__ == '__main__':
70
+ main()
generate_tracks.py ADDED
@@ -0,0 +1,70 @@
 
1
+ import os
2
+ import yaml
3
+ import tqdm
4
+ import glob
5
+ import pickle
6
+
7
+ from tracker.iou_tracker import track_iou
8
+ from detect_faces_on_videos import DETECTIONS_FILE_NAME, DETECTIONS_ROOT
9
+
10
+ SIGMA_L = 0.3
11
+ SIGMA_H = 0.9
12
+ SIGMA_IOU = 0.3
13
+ T_MIN = 1
14
+
15
+ TRACKS_FILE_NAME = 'tracks.pkl'
16
+
17
+
18
+ def get_tracks(detections):
19
+ if len(detections) == 0:
20
+ return []
21
+
22
+ converted_detections = []
23
+ for i, detections_per_frame in enumerate(detections):
24
+ converted_detections_per_frame = []
25
+ for j, (bbox, score) in enumerate(zip(detections_per_frame['boxes'], detections_per_frame['scores'])):
26
+ bbox = tuple(bbox.tolist())
27
+ converted_detections_per_frame.append({'bbox': bbox, 'score': score})
28
+ converted_detections.append(converted_detections_per_frame)
29
+
30
+ tracks = track_iou(converted_detections, SIGMA_L, SIGMA_H, SIGMA_IOU, T_MIN)
31
+ tracks_converted = []
32
+ for track in tracks:
33
+ track_converted = []
34
+ start_frame = track['start_frame'] - 1
35
+ for i, bbox in enumerate(track['bboxes']):
36
+ track_converted.append((start_frame + i, bbox))
37
+ tracks_converted.append(track_converted)
38
+
39
+ return tracks_converted
40
+
41
+
42
+ def main():
43
+ with open('config.yaml', 'r') as f:
44
+ config = yaml.load(f)
45
+
46
+ root_dir = os.path.join(config['ARTIFACTS_PATH'], DETECTIONS_ROOT)
47
+ detections_content = []
48
+ for path in glob.iglob(os.path.join(root_dir, '**', DETECTIONS_FILE_NAME), recursive=True):
49
+ rel_path = path[len(root_dir) + 1:]
50
+ detections_content.append(rel_path)
51
+
52
+ detections_content = sorted(detections_content)
53
+ print('Total number of videos: {}'.format(len(detections_content)))
54
+
55
+ video_to_tracks = {}
56
+ for rel_path in tqdm.tqdm(detections_content):
57
+ video = os.path.dirname(rel_path)
58
+ with open(os.path.join(root_dir, rel_path), 'rb') as f:
59
+ detections = pickle.load(f)
60
+ video_to_tracks[video] = get_tracks(detections)
61
+
62
+ track_count = sum([len(tracks) for tracks in video_to_tracks.values()])
63
+ print('Total number of tracks: {}'.format(track_count))
64
+
65
+ with open(os.path.join(config['ARTIFACTS_PATH'], TRACKS_FILE_NAME), 'wb') as f:
66
+ pickle.dump(video_to_tracks, f)
67
+
68
+
69
+ if __name__ == '__main__':
70
+ main()
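For reference, a sketch of the per-frame detection format that get_tracks consumes (dicts with 'boxes' and 'scores' arrays, as pickled by detect_faces_on_videos.py) and the (frame_idx, bbox) tuples it returns. The two detections are made up for illustration, and the snippet assumes the repository root is on PYTHONPATH:

import numpy as np
from generate_tracks import get_tracks

detections = [  # one dict per frame
    {'boxes': np.array([[10., 10., 50., 50.]]), 'scores': np.array([0.95])},
    {'boxes': np.array([[12., 11., 52., 51.]]), 'scores': np.array([0.97])},
]
tracks = get_tracks(detections)
print(tracks)
# expected: a single track, e.g.
# [[(0, (10.0, 10.0, 50.0, 50.0)), (1, (12.0, 11.0, 52.0, 51.0))]]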
images/augmented_mixup.jpg ADDED

Git LFS Details

  • SHA256: 99db3c3344ef4c7635fc0134f9b6dc73be249f7ac9291af9a003e7233d72ad7b
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
images/clip_example.jpg ADDED
images/first_and_second_model_inputs.jpg ADDED
images/mixup_example.jpg ADDED

Git LFS Details

  • SHA256: 74358172ce92a6ffca0796583f921e0158b933658352a990e38179949d47ae1f
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
images/pred_transform.jpg ADDED
images/third_model_input.jpg ADDED
models/.gitkeep ADDED
File without changes
predict.py ADDED
@@ -0,0 +1,399 @@
 
1
+ import os
2
+ import yaml
3
+ import glob
4
+
5
+ import numpy as np
6
+ import cv2
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ from torchvision.models.detection.transform import GeneralizedRCNNTransform
13
+
14
+ from albumentations import Compose, SmallestMaxSize, CenterCrop, Normalize, PadIfNeeded
15
+ from albumentations.pytorch import ToTensor
16
+
17
+ from dsfacedetector.face_ssd_infer import SSD
18
+ from tracker.iou_tracker import track_iou
19
+ from efficientnet_pytorch.model import EfficientNet, MBConvBlock
20
+
21
+ DETECTOR_WEIGHTS_PATH = 'WIDERFace_DSFD_RES152.fp16.pth'
22
+ DETECTOR_THRESHOLD = 0.3
23
+ DETECTOR_MIN_SIZE = 512
24
+ DETECTOR_MAX_SIZE = 512
25
+ DETECTOR_MEAN = (104.0, 117.0, 123.0)
26
+ DETECTOR_STD = (1.0, 1.0, 1.0)
27
+ DETECTOR_BATCH_SIZE = 16
28
+ DETECTOR_STEP = 3
29
+
30
+ TRACKER_SIGMA_L = 0.3
31
+ TRACKER_SIGMA_H = 0.9
32
+ TRACKER_SIGMA_IOU = 0.3
33
+ TRACKER_T_MIN = 7
34
+
35
+ VIDEO_MODEL_BBOX_MULT = 1.5
36
+ VIDEO_MODEL_MIN_SIZE = 224
37
+ VIDEO_MODEL_CROP_HEIGHT = 224
38
+ VIDEO_MODEL_CROP_WIDTH = 192
39
+ VIDEO_FACE_MODEL_TRACK_STEP = 2
40
+ VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH = 7
41
+ VIDEO_SEQUENCE_MODEL_TRACK_STEP = 14
42
+
43
+ VIDEO_SEQUENCE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth'
44
+ FIRST_VIDEO_FACE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth'
45
+ SECOND_VIDEO_FACE_MODEL_WEIGHTS_PATH = 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth'
46
+
47
+ VIDEO_BATCH_SIZE = 1
48
+ VIDEO_TARGET_FPS = 15
49
+ VIDEO_NUM_WORKERS = 0
50
+
51
+
52
+ class UnlabeledVideoDataset(Dataset):
53
+ def __init__(self, root_dir, content=None):
54
+ self.root_dir = os.path.normpath(root_dir)
55
+ if content is not None:
56
+ self.content = content
57
+ else:
58
+ self.content = []
59
+ for path in glob.iglob(os.path.join(self.root_dir, '**', '*.mp4'), recursive=True):
60
+ rel_path = path[len(self.root_dir) + 1:]
61
+ self.content.append(rel_path)
62
+ self.content = sorted(self.content)
63
+
64
+ def __len__(self):
65
+ return len(self.content)
66
+
67
+ def __getitem__(self, idx):
68
+ rel_path = self.content[idx]
69
+ path = os.path.join(self.root_dir, rel_path)
70
+
71
+ sample = {
72
+ 'frames': [],
73
+ 'index': idx
74
+ }
75
+
76
+ capture = cv2.VideoCapture(path)
77
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
78
+ if frame_count == 0:
79
+ return sample
80
+
81
+ fps = int(capture.get(cv2.CAP_PROP_FPS))
82
+ video_step = round(fps / VIDEO_TARGET_FPS)
83
+ if video_step == 0:
84
+ return sample
85
+
86
+ for i in range(frame_count):
87
+ capture.grab()
88
+ if i % video_step != 0:
89
+ continue
90
+ ret, frame = capture.retrieve()
91
+ if not ret:
92
+ continue
93
+
94
+ sample['frames'].append(frame)
95
+
96
+ return sample
97
+
98
+
99
+ class Detector(object):
100
+ def __init__(self, weights_path):
101
+ self.model = SSD('test')
102
+ self.model.cuda().eval()
103
+
104
+ state = torch.load(weights_path, map_location=lambda storage, loc: storage)
105
+ state = {key: value.float() for key, value in state.items()}
106
+ self.model.load_state_dict(state)
107
+
108
+ self.transform = GeneralizedRCNNTransform(DETECTOR_MIN_SIZE, DETECTOR_MAX_SIZE, DETECTOR_MEAN, DETECTOR_STD)
109
+ self.transform.eval()
110
+
111
+ def detect(self, images):
112
+ images = torch.stack([torch.from_numpy(image).cuda() for image in images])
113
+ images = images.transpose(1, 3).transpose(2, 3).float()
114
+ original_image_sizes = [img.shape[-2:] for img in images]
115
+ images, _ = self.transform(images, None)
116
+ with torch.no_grad():
117
+ detections_batch = self.model(images.tensors).cpu().numpy()
118
+ result = []
119
+ for detections, image_size in zip(detections_batch, images.image_sizes):
120
+ scores = detections[1, :, 0]
121
+ keep_idxs = scores > DETECTOR_THRESHOLD
122
+ detections = detections[1, keep_idxs, :]
123
+ detections = detections[:, [1, 2, 3, 4, 0]]
124
+ detections[:, 0] *= image_size[1]
125
+ detections[:, 1] *= image_size[0]
126
+ detections[:, 2] *= image_size[1]
127
+ detections[:, 3] *= image_size[0]
128
+ result.append({
129
+ 'scores': torch.from_numpy(detections[:, 4]),
130
+ 'boxes': torch.from_numpy(detections[:, :4])
131
+ })
132
+
133
+ result = self.transform.postprocess(result, images.image_sizes, original_image_sizes)
134
+ return result
135
+
136
+
137
+ def get_tracks(detections):
138
+ if len(detections) == 0:
139
+ return []
140
+
141
+ converted_detections = []
142
+ frame_bbox_to_face_idx = {}
143
+ for i, detections_per_frame in enumerate(detections):
144
+ converted_detections_per_frame = []
145
+ for j, (bbox, score) in enumerate(zip(detections_per_frame['boxes'], detections_per_frame['scores'])):
146
+ bbox = tuple(bbox.tolist())
147
+ frame_bbox_to_face_idx[(i, bbox)] = j
148
+ converted_detections_per_frame.append({'bbox': bbox, 'score': score})
149
+ converted_detections.append(converted_detections_per_frame)
150
+
151
+ tracks = track_iou(converted_detections, TRACKER_SIGMA_L, TRACKER_SIGMA_H, TRACKER_SIGMA_IOU, TRACKER_T_MIN)
152
+ tracks_converted = []
153
+ for track in tracks:
154
+ start_frame = track['start_frame'] - 1
155
+ bboxes = np.array(track['bboxes'], dtype=np.float32)
156
+ frame_indices = np.arange(start_frame, start_frame + len(bboxes)) * DETECTOR_STEP
157
+ interp_frame_indices = np.arange(frame_indices[0], frame_indices[-1] + 1)
158
+ interp_bboxes = np.zeros((len(interp_frame_indices), 4), dtype=np.float32)
159
+ for i in range(4):
160
+ interp_bboxes[:, i] = np.interp(interp_frame_indices, frame_indices, bboxes[:, i])
161
+
162
+ track_converted = []
163
+ for frame_idx, bbox in zip(interp_frame_indices, interp_bboxes):
164
+ track_converted.append((frame_idx, bbox))
165
+ tracks_converted.append(track_converted)
166
+
167
+ return tracks_converted
168
+
169
+
170
+ class SeqExpandConv(nn.Module):
171
+ def __init__(self, in_channels, out_channels, seq_length):
172
+ super(SeqExpandConv, self).__init__()
173
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
174
+ self.seq_length = seq_length
175
+
176
+ def forward(self, x):
177
+ batch_size, in_channels, height, width = x.shape
178
+ x = x.view(batch_size // self.seq_length, self.seq_length, in_channels, height, width)
179
+ x = self.conv(x.transpose(1, 2).contiguous()).transpose(2, 1).contiguous()
180
+ x = x.flatten(0, 1)
181
+ return x
182
+
183
+
184
+ class TrackSequencesClassifier(object):
185
+ def __init__(self, weights_path):
186
+ model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
187
+
188
+ for module in model.modules():
189
+ if isinstance(module, MBConvBlock):
190
+ if module._block_args.expand_ratio != 1:
191
+ expand_conv = module._expand_conv
192
+ seq_expand_conv = SeqExpandConv(expand_conv.in_channels, expand_conv.out_channels,
193
+ VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH)
194
+ module._expand_conv = seq_expand_conv
195
+ self.model = model.cuda().eval()
196
+
197
+ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
198
+ self.transform = Compose(
199
+ [SmallestMaxSize(VIDEO_MODEL_MIN_SIZE), CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
200
+ normalize, ToTensor()])
201
+
202
+ state = torch.load(weights_path, map_location=lambda storage, loc: storage)
203
+ state = {key: value.float() for key, value in state.items()}
204
+ self.model.load_state_dict(state)
205
+
206
+ def classify(self, track_sequences):
207
+ track_sequences = [torch.stack([self.transform(image=face)['image'] for face in sequence]) for sequence in
208
+ track_sequences]
209
+ track_sequences = torch.cat(track_sequences).cuda()
210
+ with torch.no_grad():
211
+ track_probs = torch.sigmoid(self.model(track_sequences)).flatten().cpu().numpy()
212
+
213
+ return track_probs
214
+
215
+
216
+ class TrackFacesClassifier(object):
217
+ def __init__(self, first_weights_path, second_weights_path):
218
+ first_model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
219
+ self.first_model = first_model.cuda().eval()
220
+ second_model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
221
+ self.second_model = second_model.cuda().eval()
222
+
223
+ first_normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
224
+ self.first_transform = Compose(
225
+ [SmallestMaxSize(VIDEO_MODEL_CROP_WIDTH), PadIfNeeded(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
226
+ CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH), first_normalize, ToTensor()])
227
+
228
+ second_normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
229
+ self.second_transform = Compose(
230
+ [SmallestMaxSize(VIDEO_MODEL_MIN_SIZE), CenterCrop(VIDEO_MODEL_CROP_HEIGHT, VIDEO_MODEL_CROP_WIDTH),
231
+ second_normalize, ToTensor()])
232
+
233
+ first_state = torch.load(first_weights_path, map_location=lambda storage, loc: storage)
234
+ first_state = {key: value.float() for key, value in first_state.items()}
235
+ self.first_model.load_state_dict(first_state)
236
+
237
+ second_state = torch.load(second_weights_path, map_location=lambda storage, loc: storage)
238
+ second_state = {key: value.float() for key, value in second_state.items()}
239
+ self.second_model.load_state_dict(second_state)
240
+
241
+ def classify(self, track_faces):
242
+ first_track_faces = []
243
+ second_track_faces = []
244
+ for i, face in enumerate(track_faces):
245
+ if i % 4 < 2:
246
+ first_track_faces.append(self.first_transform(image=face)['image'])
247
+ else:
248
+ second_track_faces.append(self.second_transform(image=face)['image'])
249
+ first_track_faces = torch.stack(first_track_faces).cuda()
250
+ second_track_faces = torch.stack(second_track_faces).cuda()
251
+ with torch.no_grad():
252
+ first_track_probs = torch.sigmoid(self.first_model(first_track_faces)).flatten().cpu().numpy()
253
+ second_track_probs = torch.sigmoid(self.second_model(second_track_faces)).flatten().cpu().numpy()
254
+ track_probs = np.concatenate((first_track_probs, second_track_probs))
255
+
256
+ return track_probs
257
+
258
+
259
+ def extract_sequence(frames, start_idx, bbox, flip):
260
+ frame_height, frame_width, _ = frames[start_idx].shape
261
+ xmin, ymin, xmax, ymax = bbox
262
+ width = xmax - xmin
263
+ height = ymax - ymin
264
+ xcenter = xmin + width / 2
265
+ ycenter = ymin + height / 2
266
+ width = width * VIDEO_MODEL_BBOX_MULT
267
+ height = height * VIDEO_MODEL_BBOX_MULT
268
+ xmin = xcenter - width / 2
269
+ ymin = ycenter - height / 2
270
+ xmax = xmin + width
271
+ ymax = ymin + height
272
+
273
+ xmin = max(int(xmin), 0)
274
+ xmax = min(int(xmax), frame_width)
275
+ ymin = max(int(ymin), 0)
276
+ ymax = min(int(ymax), frame_height)
277
+
278
+ sequence = []
279
+ for i in range(VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH):
280
+ face = cv2.cvtColor(frames[start_idx + i][ymin:ymax, xmin:xmax], cv2.COLOR_BGR2RGB)
281
+ sequence.append(face)
282
+
283
+ if flip:
284
+ sequence = [face[:, ::-1] for face in sequence]
285
+
286
+ return sequence
287
+
288
+
289
+ def extract_face(frame, bbox, flip):
290
+ frame_height, frame_width, _ = frame.shape
291
+ xmin, ymin, xmax, ymax = bbox
292
+ width = xmax - xmin
293
+ height = ymax - ymin
294
+ xcenter = xmin + width / 2
295
+ ycenter = ymin + height / 2
296
+ width = width * VIDEO_MODEL_BBOX_MULT
297
+ height = height * VIDEO_MODEL_BBOX_MULT
298
+ xmin = xcenter - width / 2
299
+ ymin = ycenter - height / 2
300
+ xmax = xmin + width
301
+ ymax = ymin + height
302
+
303
+ xmin = max(int(xmin), 0)
304
+ xmax = min(int(xmax), frame_width)
305
+ ymin = max(int(ymin), 0)
306
+ ymax = min(int(ymax), frame_height)
307
+
308
+ face = cv2.cvtColor(frame[ymin:ymax, xmin:xmax], cv2.COLOR_BGR2RGB)
309
+ if flip:
310
+ face = face[:, ::-1].copy()
311
+
312
+ return face
313
+
314
+
315
+ def main():
316
+ with open('config.yaml', 'r') as f:
317
+ config = yaml.safe_load(f)
318
+
319
+ detector = Detector(os.path.join(config['MODELS_PATH'], DETECTOR_WEIGHTS_PATH))
320
+ track_sequences_classifier = TrackSequencesClassifier(os.path.join(config['MODELS_PATH'], VIDEO_SEQUENCE_MODEL_WEIGHTS_PATH))
321
+ track_faces_classifier = TrackFacesClassifier(os.path.join(config['MODELS_PATH'], FIRST_VIDEO_FACE_MODEL_WEIGHTS_PATH),
322
+ os.path.join(config['MODELS_PATH'], SECOND_VIDEO_FACE_MODEL_WEIGHTS_PATH))
323
+
324
+ dataset = UnlabeledVideoDataset(os.path.join(config['DFDC_DATA_PATH'], 'test_videos'))
325
+ print('Total number of videos: {}'.format(len(dataset)))
326
+
327
+ loader = DataLoader(dataset, batch_size=VIDEO_BATCH_SIZE, shuffle=False, num_workers=VIDEO_NUM_WORKERS,
328
+ collate_fn=lambda X: X,
329
+ drop_last=False)
330
+
331
+ video_name_to_score = {}
332
+
333
+ for video_sample in loader:
334
+ frames = video_sample[0]['frames']
335
+ detector_frames = frames[::DETECTOR_STEP]
336
+ video_idx = video_sample[0]['index']
337
+ video_rel_path = dataset.content[video_idx]
338
+ video_name = os.path.basename(video_rel_path)
339
+
340
+ if len(frames) == 0:
341
+ video_name_to_score[video_name] = 0.5
342
+ continue
343
+
344
+ detections = []
345
+ for start in range(0, len(detector_frames), DETECTOR_BATCH_SIZE):
346
+ end = min(len(detector_frames), start + DETECTOR_BATCH_SIZE)
347
+ detections_batch = detector.detect(detector_frames[start:end])
348
+ for detections_per_frame in detections_batch:
349
+ detections.append({key: value.cpu().numpy() for key, value in detections_per_frame.items()})
350
+
351
+ tracks = get_tracks(detections)
352
+ if len(tracks) == 0:
353
+ video_name_to_score[video_name] = 0.5
354
+ continue
355
+
356
+ sequence_track_scores = []
357
+ for track in tracks:
358
+ track_sequences = []
359
+ for i, (start_idx, _) in enumerate(
360
+ track[:-VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH + 1:VIDEO_SEQUENCE_MODEL_TRACK_STEP]):
361
+ assert start_idx >= 0 and start_idx + VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH <= len(frames)
362
+ _, bbox = track[i * VIDEO_SEQUENCE_MODEL_TRACK_STEP + VIDEO_SEQUENCE_MODEL_SEQUENCE_LENGTH // 2]
363
+ track_sequences.append(extract_sequence(frames, start_idx, bbox, i % 2 == 0))
364
+ sequence_track_scores.append(track_sequences_classifier.classify(track_sequences))
365
+
366
+ face_track_scores = []
367
+ for track in tracks:
368
+ track_faces = []
369
+ for i, (frame_idx, bbox) in enumerate(track[::VIDEO_FACE_MODEL_TRACK_STEP]):
370
+ face = extract_face(frames[frame_idx], bbox, i % 2 == 0)
371
+ track_faces.append(face)
372
+ face_track_scores.append(track_faces_classifier.classify(track_faces))
373
+
374
+ sequence_track_scores = np.concatenate(sequence_track_scores)
375
+ face_track_scores = np.concatenate(face_track_scores)
376
+ track_probs = np.concatenate((sequence_track_scores, face_track_scores))
377
+
378
+ delta = track_probs - 0.5
379
+ sign = np.sign(delta)
380
+ pos_delta = delta > 0
381
+ neg_delta = delta < 0
382
+ track_probs[pos_delta] = np.clip(0.5 + sign[pos_delta] * np.power(abs(delta[pos_delta]), 0.65), 0.01, 0.99)
383
+ track_probs[neg_delta] = np.clip(0.5 + sign[neg_delta] * np.power(abs(delta[neg_delta]), 0.65), 0.01, 0.99)
384
+ weights = np.power(abs(delta), 1.0) + 1e-4
385
+ video_score = float((track_probs * weights).sum() / weights.sum())
386
+
387
+ video_name_to_score[video_name] = video_score
388
+ print('NUM DETECTION FRAMES: {}, VIDEO SCORE: {}. {}'.format(len(detections), video_name_to_score[video_name],
389
+ video_rel_path))
390
+
391
+ os.makedirs(os.path.dirname(config['SUBMISSION_PATH']), exist_ok=True)
392
+ with open(config['SUBMISSION_PATH'], 'w') as f:
393
+ f.write('filename,label\n')
394
+ for video_name in sorted(video_name_to_score):
395
+ score = video_name_to_score[video_name]
396
+ f.write('{},{}\n'.format(video_name, score))
397
+
398
+
399
+ main()
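The score aggregation at the end of predict.py first sharpens each per-track probability away from the undecided value 0.5 (power 0.65, clipped to [0.01, 0.99]) and then averages with weights proportional to each prediction's distance from 0.5. A minimal standalone sketch of that transform, with the constants copied from above (the function name aggregate_video_score is illustrative and not part of the repository):

    import numpy as np

    def aggregate_video_score(track_probs, power=0.65, eps=1e-4):
        delta = track_probs - 0.5
        # Push predictions away from 0.5 so confident tracks dominate the average.
        sharpened = np.clip(0.5 + np.sign(delta) * np.power(np.abs(delta), power), 0.01, 0.99)
        # Weight each prediction by its original distance from 0.5.
        weights = np.abs(delta) + eps
        return float((sharpened * weights).sum() / weights.sum())

    # Example: two confident "fake" tracks outweigh one near-undecided track.
    print(aggregate_video_score(np.array([0.9, 0.85, 0.55])))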
tracker/__init__.py ADDED
File without changes
tracker/iou_tracker.py ADDED
@@ -0,0 +1,58 @@
1
+ # Source: https://github.com/bochinski/iou-tracker
2
+
3
+ from .utils import iou
4
+
5
+
6
+ def track_iou(detections, sigma_l, sigma_h, sigma_iou, t_min):
7
+ """
8
+ Simple IOU based tracker.
9
+ See "High-Speed Tracking-by-Detection Without Using Image Information by E. Bochinski, V. Eiselein, T. Sikora" for
10
+ more information.
11
+
12
+ Args:
13
+ detections (list): list of detections per frame, usually generated by util.load_mot
14
+ sigma_l (float): low detection threshold.
15
+ sigma_h (float): high detection threshold.
16
+ sigma_iou (float): IOU threshold.
17
+ t_min (float): minimum track length in frames.
18
+
19
+ Returns:
20
+ list: list of tracks.
21
+ """
22
+
23
+ tracks_active = []
24
+ tracks_finished = []
25
+
26
+ for frame_num, detections_frame in enumerate(detections, start=1):
27
+ # apply low threshold to detections
28
+ dets = [det for det in detections_frame if det['score'] >= sigma_l]
29
+
30
+ updated_tracks = []
31
+ for track in tracks_active:
32
+ if len(dets) > 0:
33
+ # get det with highest iou
34
+ best_match = max(dets, key=lambda x: iou(track['bboxes'][-1], x['bbox']))
35
+ if iou(track['bboxes'][-1], best_match['bbox']) >= sigma_iou:
36
+ track['bboxes'].append(best_match['bbox'])
37
+ track['max_score'] = max(track['max_score'], best_match['score'])
38
+
39
+ updated_tracks.append(track)
40
+
41
+ # remove best matching detection from detections
42
+ del dets[dets.index(best_match)]
43
+
44
+ # if track was not updated
45
+ if len(updated_tracks) == 0 or track is not updated_tracks[-1]:
46
+ # finish track when the conditions are met
47
+ if track['max_score'] >= sigma_h and len(track['bboxes']) >= t_min:
48
+ tracks_finished.append(track)
49
+
50
+ # create new tracks
51
+ new_tracks = [{'bboxes': [det['bbox']], 'max_score': det['score'], 'start_frame': frame_num} for det in dets]
52
+ tracks_active = updated_tracks + new_tracks
53
+
54
+ # finish all remaining active tracks
55
+ tracks_finished += [track for track in tracks_active
56
+ if track['max_score'] >= sigma_h and len(track['bboxes']) >= t_min]
57
+
58
+ return tracks_finished
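A small usage sketch of track_iou with hand-written detections in the same format produced by get_tracks in predict.py (boxes and scores are made up for illustration, and t_min is lowered to 3 so the toy track survives; predict.py itself uses t_min=7):

    from tracker.iou_tracker import track_iou

    # Three frames with a single slowly moving, high-confidence box.
    detections = [
        [{'bbox': (10, 10, 50, 50), 'score': 0.95}],
        [{'bbox': (12, 11, 52, 51), 'score': 0.92}],
        [{'bbox': (14, 12, 54, 52), 'score': 0.90}],
    ]
    tracks = track_iou(detections, sigma_l=0.3, sigma_h=0.9, sigma_iou=0.3, t_min=3)
    print(len(tracks), tracks[0]['start_frame'], len(tracks[0]['bboxes']))  # 1 1 3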
tracker/utils.py ADDED
@@ -0,0 +1,35 @@
1
+ def iou(bbox1, bbox2):
2
+ """
3
+ Calculates the intersection-over-union of two bounding boxes.
4
+
5
+ Args:
6
+ bbox1 (numpy.array, list of floats): bounding box in format x1,y1,x2,y2.
7
+ bbox2 (numpy.array, list of floats): bounding box in format x1,y1,x2,y2.
8
+
9
+ Returns:
10
+ float: intersection-over-union of bbox1, bbox2
11
+ """
12
+
13
+ bbox1 = [float(x) for x in bbox1]
14
+ bbox2 = [float(x) for x in bbox2]
15
+
16
+ (x0_1, y0_1, x1_1, y1_1) = bbox1
17
+ (x0_2, y0_2, x1_2, y1_2) = bbox2
18
+
19
+ # get the overlap rectangle
20
+ overlap_x0 = max(x0_1, x0_2)
21
+ overlap_y0 = max(y0_1, y0_2)
22
+ overlap_x1 = min(x1_1, x1_2)
23
+ overlap_y1 = min(y1_1, y1_2)
24
+
25
+ # check if there is an overlap
26
+ if overlap_x1 - overlap_x0 <= 0 or overlap_y1 - overlap_y0 <= 0:
27
+ return 0
28
+
29
+ # if yes, calculate the ratio of the overlap to each ROI size and the unified size
30
+ size_1 = (x1_1 - x0_1) * (y1_1 - y0_1)
31
+ size_2 = (x1_2 - x0_2) * (y1_2 - y0_2)
32
+ size_intersection = (overlap_x1 - overlap_x0) * (overlap_y1 - overlap_y0)
33
+ size_union = size_1 + size_2 - size_intersection
34
+
35
+ return size_intersection / size_union
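A quick sanity check of iou on two unit-area boxes sharing half of their area, where the expected value is 0.5 / (1 + 1 - 0.5) = 1/3:

    from tracker.utils import iou

    print(iou([0.0, 0.0, 1.0, 1.0], [0.5, 0.0, 1.5, 1.0]))  # 0.3333...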
train_b7_ns_aa_original_large_crop_100k.py ADDED
@@ -0,0 +1,257 @@
1
+ import yaml
2
+ import os
3
+ import random
4
+ import tqdm
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch import distributions
11
+ from torch.nn import functional as F
12
+ from torch.utils.data import DataLoader
13
+
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ import ffmpeg
17
+
18
+ from albumentations import ImageOnlyTransform
19
+ from albumentations import SmallestMaxSize, PadIfNeeded, HorizontalFlip, Normalize, Compose, RandomCrop
20
+ from albumentations.pytorch import ToTensor
21
+ from efficientnet_pytorch import EfficientNet
22
+
23
+ from timm.data.transforms_factory import transforms_imagenet_train
24
+
25
+ from datasets import TrackPairDataset
26
+ from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
27
+ from generate_track_pairs import TRACK_PAIRS_FILE_NAME
28
+
29
+ SEED = 30
30
+ BATCH_SIZE = 8
31
+ TRAIN_INDICES = [9, 13, 17, 21, 25, 29, 33, 37]
32
+ INITIAL_LR = 0.005
33
+ MOMENTUM = 0.9
34
+ WEIGHT_DECAY = 1e-4
35
+ NUM_WORKERS = 8
36
+ NUM_WARMUP_ITERATIONS = 100
37
+ SNAPSHOT_FREQUENCY = 1000
38
+ OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k'
39
+ SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
40
+ MAX_ITERS = 100000
41
+
42
+ FPS_RANGE = (15, 30)
43
+ SCALE_RANGE = (0.25, 1)
44
+ CRF_RANGE = (17, 40)
45
+ TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
46
+
47
+ CROP_HEIGHT = 224
48
+ CROP_WIDTH = 192
49
+
50
+ PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
51
+ SNAPSHOTS_ROOT = 'snapshots'
52
+ LOGS_ROOT = 'logs'
53
+
54
+
55
+ class TrackTransform(object):
56
+ def __init__(self, fps_range, scale_range, crf_range, tune_values):
57
+ self.fps_range = fps_range
58
+ self.scale_range = scale_range
59
+ self.crf_range = crf_range
60
+ self.tune_values = tune_values
61
+
62
+ def get_params(self, src_fps, src_height, src_width):
63
+ if random.random() > 0.5:
64
+ return None
65
+
66
+ dst_fps = src_fps
67
+ if random.random() > 0.5:
68
+ dst_fps = random.randrange(*self.fps_range)
69
+
70
+ scale = 1.0
71
+ if random.random() > 0.5:
72
+ scale = random.uniform(*self.scale_range)
73
+
74
+ dst_height = round(scale * src_height) // 2 * 2
75
+ dst_width = round(scale * src_width) // 2 * 2
76
+
77
+ crf = random.randrange(*self.crf_range)
78
+ tune = random.choice(self.tune_values)
79
+
80
+ return dst_fps, dst_height, dst_width, crf, tune
81
+
82
+ def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
83
+ out, err = (
84
+ ffmpeg
85
+ .input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
86
+ .filter('fps', fps=dst_fps)
87
+ .filter('scale', dst_width, dst_height)
88
+ .output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
89
+ .run(capture_stdout=True, quiet=True)
90
+ )
91
+ out, err = (
92
+ ffmpeg
93
+ .input('pipe:', format='h264')
94
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24')
95
+ .run(capture_stdout=True, input=out, quiet=True)
96
+ )
97
+
98
+ imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
99
+
100
+ return imgs
101
+
102
+
103
+ class VisionTransform(ImageOnlyTransform):
104
+ def __init__(
105
+ self, transform, is_tensor=True, always_apply=False, p=1.0
106
+ ):
107
+ super(VisionTransform, self).__init__(always_apply, p)
108
+ self.transform = transform
109
+ self.is_tensor = is_tensor
110
+
111
+ def apply(self, image, **params):
112
+ if self.is_tensor:
113
+ return self.transform(image)
114
+ else:
115
+ return np.array(self.transform(Image.fromarray(image)))
116
+
117
+ def get_transform_init_args_names(self):
118
+ return ("transform",)
119
+
120
+
121
+ def set_global_seed(seed):
122
+ torch.manual_seed(seed)
123
+ if torch.cuda.is_available():
124
+ torch.cuda.manual_seed_all(seed)
125
+ random.seed(seed)
126
+ np.random.seed(seed)
127
+
128
+
129
+ def prepare_cudnn(deterministic=None, benchmark=None):
130
+ # https://pytorch.org/docs/stable/notes/randomness.html#cudnn
131
+ if deterministic is None:
132
+ deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
133
+ torch.backends.cudnn.deterministic = deterministic
134
+
135
+ # https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
136
+ if benchmark is None:
137
+ benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
138
+ torch.backends.cudnn.benchmark = benchmark
139
+
140
+
141
+ def main():
142
+ with open('config.yaml', 'r') as f:
143
+ config = yaml.safe_load(f)
144
+
145
+ set_global_seed(SEED)
146
+ prepare_cudnn(deterministic=True, benchmark=True)
147
+
148
+ model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
149
+ state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
150
+ state.pop('_fc.weight')
151
+ state.pop('_fc.bias')
152
+ res = model.load_state_dict(state, strict=False)
153
+ assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
154
+ model = model.cuda()
155
+
156
+ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
157
+ _, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
158
+ separate=True)
159
+
160
+ train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
161
+ os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
162
+ TRAIN_INDICES,
163
+ track_length=TRACK_LENGTH,
164
+ track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
165
+ image_transform=Compose([
166
+ SmallestMaxSize(CROP_WIDTH),
167
+ PadIfNeeded(CROP_HEIGHT, CROP_WIDTH),
168
+ HorizontalFlip(),
169
+ RandomCrop(CROP_HEIGHT, CROP_WIDTH),
170
+ VisionTransform(rand_augment, is_tensor=False, p=0.5),
171
+ normalize,
172
+ ToTensor()
173
+ ]), sequence_mode=False)
174
+
175
+ print('Train dataset size: {}.'.format(len(train_dataset)))
176
+
177
+ warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
178
+ weight_decay=WEIGHT_DECAY, nesterov=True)
179
+
180
+ full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
181
+ nesterov=True)
182
+ full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
183
+ lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
184
+
185
+ snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
186
+ os.makedirs(snapshots_root)
187
+ log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
188
+ os.makedirs(log_root)
189
+
190
+ writer = SummaryWriter(log_root)
191
+
192
+ iteration = 0
193
+ if iteration < NUM_WARMUP_ITERATIONS:
194
+ print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
195
+ model.eval()
196
+ model._fc.train()
197
+ for param in model.parameters():
198
+ param.requires_grad = False
199
+ for param in model._fc.parameters():
200
+ param.requires_grad = True
201
+ optimizer = warmup_optimizer
202
+ else:
203
+ print('Start without warmup iterations')
204
+ model.train()
205
+ optimizer = full_optimizer
206
+
207
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
208
+ writer.add_scalar('train/max_lr', max_lr, iteration)
209
+
210
+ epoch = 0
211
+ fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
212
+ while True:
213
+ epoch += 1
214
+ print('Epoch {} is in progress'.format(epoch))
215
+ loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
216
+ for samples in tqdm.tqdm(loader):
217
+ iteration += 1
218
+ fake_input_tensor = torch.cat(samples['fake']).cuda()
219
+ real_input_tensor = torch.cat(samples['real']).cuda()
220
+ target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
221
+ fake_weight = target_fake_prob.view(-1, 1, 1, 1)
222
+
223
+ input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
224
+ pred = model(input_tensor).flatten()
225
+
226
+ loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob)
227
+
228
+ optimizer.zero_grad()
229
+ loss.backward()
230
+ optimizer.step()
231
+ if iteration > NUM_WARMUP_ITERATIONS:
232
+ full_lr_scheduler.step()
233
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
234
+ writer.add_scalar('train/max_lr', max_lr, iteration)
235
+
236
+ writer.add_scalar('train/loss', loss.item(), iteration)
237
+
238
+ if iteration == NUM_WARMUP_ITERATIONS:
239
+ print('Stop warmup iterations')
240
+ model.train()
241
+ for param in model.parameters():
242
+ param.requires_grad = True
243
+ optimizer = full_optimizer
244
+
245
+ if iteration % SNAPSHOT_FREQUENCY == 0:
246
+ snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
247
+ snapshot_path = os.path.join(snapshots_root, snapshot_name)
248
+ print('Saving snapshot to {}'.format(snapshot_path))
249
+ torch.save(model.state_dict(), snapshot_path)
250
+
251
+ if iteration >= MAX_ITERS:
252
+ print('Stop training: maximum iteration count exceeded')
253
+ return
254
+
255
+
256
+ if __name__ == '__main__':
257
+ main()
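The inner loop above trains on mixup-blended pairs: each real crop and its aligned fake crop are mixed with a Beta(0.5, 0.5)-sampled weight, and that same weight is reused as the soft fake-probability target for binary cross-entropy. A minimal sketch of that step in isolation (the dummy linear model and the 8x8 "crops" are illustrative only):

    import torch
    from torch import distributions
    from torch.nn import functional as F

    fake_prob_dist = distributions.beta.Beta(0.5, 0.5)

    def mixup_step(model, real_batch, fake_batch):
        # One blending weight per sample; it doubles as the regression target.
        target = fake_prob_dist.sample((real_batch.size(0),)).to(real_batch.device)
        weight = target.view(-1, 1, 1, 1)
        mixed = (1.0 - weight) * real_batch + weight * fake_batch
        pred = model(mixed).flatten()
        return F.binary_cross_entropy_with_logits(pred, target)

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 1))
    loss = mixup_step(model, torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8))
    print(loss.item())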
train_b7_ns_aa_original_re_100k.py ADDED
@@ -0,0 +1,266 @@
1
+ import yaml
2
+ import os
3
+ import random
4
+ import tqdm
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch import distributions
11
+ from torch.nn import functional as F
12
+ from torch.utils.data import DataLoader
13
+
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ import ffmpeg
17
+
18
+ from albumentations import ImageOnlyTransform
19
+ from albumentations import SmallestMaxSize, HorizontalFlip, Normalize, Compose, RandomCrop
20
+ from albumentations.pytorch import ToTensor
21
+ from efficientnet_pytorch import EfficientNet
22
+
23
+ from timm.data.transforms_factory import transforms_imagenet_train
24
+ from timm.data.random_erasing import RandomErasing
25
+
26
+ from datasets import TrackPairDataset
27
+ from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
28
+ from generate_track_pairs import TRACK_PAIRS_FILE_NAME
29
+
30
+ SEED = 10
31
+ BATCH_SIZE = 8
32
+ TRAIN_INDICES = [9, 13, 17, 21, 25, 29, 33, 37]
33
+ INITIAL_LR = 0.005
34
+ MOMENTUM = 0.9
35
+ WEIGHT_DECAY = 1e-4
36
+ NUM_WORKERS = 8
37
+ NUM_WARMUP_ITERATIONS = 100
38
+ SNAPSHOT_FREQUENCY = 1000
39
+ OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_aa-original-mstd0.5_re_100k'
40
+ SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
41
+ MAX_ITERS = 100000
42
+
43
+ FPS_RANGE = (15, 30)
44
+ SCALE_RANGE = (0.25, 1)
45
+ CRF_RANGE = (17, 40)
46
+ TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
47
+
48
+ RE_PROB = 0.2
49
+ RE_MODE = 'pixel'
50
+ RE_COUNT = 1
51
+ RE_NUM_SPLITS = 0
52
+
53
+ MIN_SIZE = 224
54
+ CROP_HEIGHT = 224
55
+ CROP_WIDTH = 192
56
+
57
+ PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
58
+ SNAPSHOTS_ROOT = 'snapshots'
59
+ LOGS_ROOT = 'logs'
60
+
61
+
62
+ class TrackTransform(object):
63
+ def __init__(self, fps_range, scale_range, crf_range, tune_values):
64
+ self.fps_range = fps_range
65
+ self.scale_range = scale_range
66
+ self.crf_range = crf_range
67
+ self.tune_values = tune_values
68
+
69
+ def get_params(self, src_fps, src_height, src_width):
70
+ if random.random() > 0.5:
71
+ return None
72
+
73
+ dst_fps = src_fps
74
+ if random.random() > 0.5:
75
+ dst_fps = random.randrange(*self.fps_range)
76
+
77
+ scale = 1.0
78
+ if random.random() > 0.5:
79
+ scale = random.uniform(*self.scale_range)
80
+
81
+ dst_height = round(scale * src_height) // 2 * 2
82
+ dst_width = round(scale * src_width) // 2 * 2
83
+
84
+ crf = random.randrange(*self.crf_range)
85
+ tune = random.choice(self.tune_values)
86
+
87
+ return dst_fps, dst_height, dst_width, crf, tune
88
+
89
+ def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
90
+ out, err = (
91
+ ffmpeg
92
+ .input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
93
+ .filter('fps', fps=dst_fps)
94
+ .filter('scale', dst_width, dst_height)
95
+ .output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
96
+ .run(capture_stdout=True, quiet=True)
97
+ )
98
+ out, err = (
99
+ ffmpeg
100
+ .input('pipe:', format='h264')
101
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24')
102
+ .run(capture_stdout=True, input=out, quiet=True)
103
+ )
104
+
105
+ imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
106
+
107
+ return imgs
108
+
109
+
110
+ class VisionTransform(ImageOnlyTransform):
111
+ def __init__(
112
+ self, transform, is_tensor=True, always_apply=False, p=1.0
113
+ ):
114
+ super(VisionTransform, self).__init__(always_apply, p)
115
+ self.transform = transform
116
+ self.is_tensor = is_tensor
117
+
118
+ def apply(self, image, **params):
119
+ if self.is_tensor:
120
+ return self.transform(image)
121
+ else:
122
+ return np.array(self.transform(Image.fromarray(image)))
123
+
124
+ def get_transform_init_args_names(self):
125
+ return ("transform",)
126
+
127
+
128
+ def set_global_seed(seed):
129
+ torch.manual_seed(seed)
130
+ if torch.cuda.is_available():
131
+ torch.cuda.manual_seed_all(seed)
132
+ random.seed(seed)
133
+ np.random.seed(seed)
134
+
135
+
136
+ def prepare_cudnn(deterministic=None, benchmark=None):
137
+ # https://pytorch.org/docs/stable/notes/randomness.html#cudnn
138
+ if deterministic is None:
139
+ deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
140
+ torch.backends.cudnn.deterministic = deterministic
141
+
142
+ # https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
143
+ if benchmark is None:
144
+ benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
145
+ torch.backends.cudnn.benchmark = benchmark
146
+
147
+
148
+ def main():
149
+ with open('config.yaml', 'r') as f:
150
+ config = yaml.safe_load(f)
151
+
152
+ set_global_seed(SEED)
153
+ prepare_cudnn(deterministic=True, benchmark=True)
154
+
155
+ model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
156
+ state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
157
+ state.pop('_fc.weight')
158
+ state.pop('_fc.bias')
159
+ res = model.load_state_dict(state, strict=False)
160
+ assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
161
+ model = model.cuda()
162
+
163
+ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
164
+ _, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
165
+ separate=True)
166
+
167
+ train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
168
+ os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
169
+ TRAIN_INDICES,
170
+ track_length=TRACK_LENGTH,
171
+ track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
172
+ image_transform=Compose([
173
+ SmallestMaxSize(MIN_SIZE),
174
+ HorizontalFlip(),
175
+ RandomCrop(CROP_HEIGHT, CROP_WIDTH),
176
+ VisionTransform(rand_augment, is_tensor=False, p=0.5),
177
+ normalize,
178
+ ToTensor(),
179
+ VisionTransform(
180
+ RandomErasing(probability=RE_PROB, mode=RE_MODE, max_count=RE_COUNT,
181
+ num_splits=RE_NUM_SPLITS, device='cpu'), is_tensor=True)
182
+ ]), sequence_mode=False)
183
+
184
+ print('Train dataset size: {}.'.format(len(train_dataset)))
185
+
186
+ warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
187
+ weight_decay=WEIGHT_DECAY, nesterov=True)
188
+
189
+ full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
190
+ nesterov=True)
191
+ full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
192
+ lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
193
+
194
+ snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
195
+ os.makedirs(snapshots_root)
196
+ log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
197
+ os.makedirs(log_root)
198
+
199
+ writer = SummaryWriter(log_root)
200
+
201
+ iteration = 0
202
+ if iteration < NUM_WARMUP_ITERATIONS:
203
+ print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
204
+ model.eval()
205
+ model._fc.train()
206
+ for param in model.parameters():
207
+ param.requires_grad = False
208
+ for param in model._fc.parameters():
209
+ param.requires_grad = True
210
+ optimizer = warmup_optimizer
211
+ else:
212
+ print('Start without warmup iterations')
213
+ model.train()
214
+ optimizer = full_optimizer
215
+
216
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
217
+ writer.add_scalar('train/max_lr', max_lr, iteration)
218
+
219
+ epoch = 0
220
+ fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
221
+ while True:
222
+ epoch += 1
223
+ print('Epoch {} is in progress'.format(epoch))
224
+ loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
225
+ for samples in tqdm.tqdm(loader):
226
+ iteration += 1
227
+ fake_input_tensor = torch.cat(samples['fake']).cuda()
228
+ real_input_tensor = torch.cat(samples['real']).cuda()
229
+ target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
230
+ fake_weight = target_fake_prob.view(-1, 1, 1, 1)
231
+
232
+ input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
233
+ pred = model(input_tensor).flatten()
234
+
235
+ loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob)
236
+
237
+ optimizer.zero_grad()
238
+ loss.backward()
239
+ optimizer.step()
240
+ if iteration > NUM_WARMUP_ITERATIONS:
241
+ full_lr_scheduler.step()
242
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
243
+ writer.add_scalar('train/max_lr', max_lr, iteration)
244
+
245
+ writer.add_scalar('train/loss', loss.item(), iteration)
246
+
247
+ if iteration == NUM_WARMUP_ITERATIONS:
248
+ print('Stop warmup iterations')
249
+ model.train()
250
+ for param in model.parameters():
251
+ param.requires_grad = True
252
+ optimizer = full_optimizer
253
+
254
+ if iteration % SNAPSHOT_FREQUENCY == 0:
255
+ snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
256
+ snapshot_path = os.path.join(snapshots_root, snapshot_name)
257
+ print('Saving snapshot to {}'.format(snapshot_path))
258
+ torch.save(model.state_dict(), snapshot_path)
259
+
260
+ if iteration >= MAX_ITERS:
261
+ print('Stop training: maximum iteration count exceeded')
262
+ return
263
+
264
+
265
+ if __name__ == '__main__':
266
+ main()
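All three training scripts use the same warmup scheme: for the first NUM_WARMUP_ITERATIONS steps only the _fc head is trainable while the backbone stays frozen in eval mode, after which every parameter is unfrozen and the full SGD optimizer with the linearly decaying schedule takes over. A condensed sketch of that freeze/unfreeze pattern (the helper name set_warmup_mode is illustrative):

    from efficientnet_pytorch import EfficientNet

    def set_warmup_mode(model, warmup):
        if warmup:
            model.eval()        # keep backbone BatchNorm statistics frozen
            model._fc.train()
        else:
            model.train()
        for param in model.parameters():
            param.requires_grad = not warmup
        for param in model._fc.parameters():
            param.requires_grad = True

    model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
    set_warmup_mode(model, warmup=True)   # head-only warmup phase
    set_warmup_mode(model, warmup=False)  # full fine-tuning afterwards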
train_b7_ns_seq_aa_original_100k.py ADDED
@@ -0,0 +1,281 @@
1
+ import yaml
2
+ import os
3
+ import random
4
+ import tqdm
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch import distributions
12
+ from torch.nn import functional as F
13
+ from torch.utils.data import DataLoader
14
+
15
+ from torch.utils.tensorboard import SummaryWriter
16
+
17
+ import ffmpeg
18
+
19
+ from albumentations import ImageOnlyTransform
20
+ from albumentations import SmallestMaxSize, HorizontalFlip, Normalize, Compose, RandomCrop
21
+ from albumentations.pytorch import ToTensor
22
+ from efficientnet_pytorch import EfficientNet
23
+ from efficientnet_pytorch.model import MBConvBlock
24
+
25
+ from timm.data.transforms_factory import transforms_imagenet_train
26
+
27
+ from datasets import TrackPairDataset
28
+ from extract_tracks_from_videos import TRACK_LENGTH, TRACKS_ROOT
29
+ from generate_track_pairs import TRACK_PAIRS_FILE_NAME
30
+
31
+ SEED = 20
32
+ BATCH_SIZE = 8
33
+ TRAIN_INDICES = [19, 21, 23, 25, 27, 29, 31]
34
+ INITIAL_LR = 0.005
35
+ MOMENTUM = 0.9
36
+ WEIGHT_DECAY = 1e-4
37
+ NUM_WORKERS = 8
38
+ NUM_WARMUP_ITERATIONS = 100
39
+ SNAPSHOT_FREQUENCY = 1000
40
+ OUTPUT_FOLDER_NAME = 'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k'
41
+ SNAPSHOT_NAME_TEMPLATE = 'snapshot_{}.pth'
42
+ FINAL_SNAPSHOT_NAME = 'final.pth'
43
+ MAX_ITERS = 100000
44
+
45
+ FPS_RANGE = (15, 30)
46
+ SCALE_RANGE = (0.25, 1)
47
+ CRF_RANGE = (17, 40)
48
+ TUNE_VALUES = ['film', 'animation', 'grain', 'stillimage', 'fastdecode', 'zerolatency']
49
+
50
+ MIN_SIZE = 224
51
+ CROP_HEIGHT = 224
52
+ CROP_WIDTH = 192
53
+
54
+ PRETRAINED_WEIGHTS_PATH = 'external_data/noisy_student_efficientnet-b7.pth'
55
+ SNAPSHOTS_ROOT = 'snapshots'
56
+ LOGS_ROOT = 'logs'
57
+
58
+
59
+ class SeqExpandConv(nn.Module):
60
+ def __init__(self, in_channels, out_channels, seq_length):
61
+ super(SeqExpandConv, self).__init__()
62
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
63
+ self.seq_length = seq_length
64
+
65
+ def forward(self, x):
66
+ batch_size, in_channels, height, width = x.shape
67
+ x = x.view(batch_size // self.seq_length, self.seq_length, in_channels, height, width)
68
+ x = self.conv(x.transpose(1, 2).contiguous()).transpose(2, 1).contiguous()
69
+ x = x.flatten(0, 1)
70
+ return x
71
+
72
+
73
+ class TrackTransform(object):
74
+ def __init__(self, fps_range, scale_range, crf_range, tune_values):
75
+ self.fps_range = fps_range
76
+ self.scale_range = scale_range
77
+ self.crf_range = crf_range
78
+ self.tune_values = tune_values
79
+
80
+ def get_params(self, src_fps, src_height, src_width):
81
+ if random.random() > 0.5:
82
+ return None
83
+
84
+ dst_fps = src_fps
85
+ if random.random() > 0.5:
86
+ dst_fps = random.randrange(*self.fps_range)
87
+
88
+ scale = 1.0
89
+ if random.random() > 0.5:
90
+ scale = random.uniform(*self.scale_range)
91
+
92
+ dst_height = round(scale * src_height) // 2 * 2
93
+ dst_width = round(scale * src_width) // 2 * 2
94
+
95
+ crf = random.randrange(*self.crf_range)
96
+ tune = random.choice(self.tune_values)
97
+
98
+ return dst_fps, dst_height, dst_width, crf, tune
99
+
100
+ def __call__(self, track_path, src_fps, dst_fps, dst_height, dst_width, crf, tune):
101
+ out, err = (
102
+ ffmpeg
103
+ .input(os.path.join(track_path, '%d.png'), framerate=src_fps, start_number=0)
104
+ .filter('fps', fps=dst_fps)
105
+ .filter('scale', dst_width, dst_height)
106
+ .output('pipe:', format='h264', vcodec='libx264', crf=crf, tune=tune)
107
+ .run(capture_stdout=True, quiet=True)
108
+ )
109
+ out, err = (
110
+ ffmpeg
111
+ .input('pipe:', format='h264')
112
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24')
113
+ .run(capture_stdout=True, input=out, quiet=True)
114
+ )
115
+
116
+ imgs = np.frombuffer(out, dtype=np.uint8).reshape(-1, dst_height, dst_width, 3)
117
+
118
+ return imgs
119
+
120
+
121
+ class VisionTransform(ImageOnlyTransform):
122
+ def __init__(
123
+ self, transform, always_apply=False, p=1.0
124
+ ):
125
+ super(VisionTransform, self).__init__(always_apply, p)
126
+ self.transform = transform
127
+
128
+ def apply(self, image, **params):
129
+ return np.array(self.transform(Image.fromarray(image)))
130
+
131
+ def get_transform_init_args_names(self):
132
+ return ("transform",)
133
+
134
+
135
+ def set_global_seed(seed):
136
+ torch.manual_seed(seed)
137
+ if torch.cuda.is_available():
138
+ torch.cuda.manual_seed_all(seed)
139
+ random.seed(seed)
140
+ np.random.seed(seed)
141
+
142
+
143
+ def prepare_cudnn(deterministic=None, benchmark=None):
144
+ # https://pytorch.org/docs/stable/notes/randomness.html#cudnn
145
+ if deterministic is None:
146
+ deterministic = os.environ.get("CUDNN_DETERMINISTIC", "True") == "True"
147
+ torch.backends.cudnn.deterministic = deterministic
148
+
149
+ # https://discuss.pytorch.org/t/how-should-i-disable-using-cudnn-in-my-code/38053/4
150
+ if benchmark is None:
151
+ benchmark = os.environ.get("CUDNN_BENCHMARK", "True") == "True"
152
+ torch.backends.cudnn.benchmark = benchmark
153
+
154
+
155
+ def main():
156
+ with open('config.yaml', 'r') as f:
157
+ config = yaml.safe_load(f)
158
+
159
+ set_global_seed(SEED)
160
+ prepare_cudnn(deterministic=True, benchmark=True)
161
+
162
+ model = EfficientNet.from_name('efficientnet-b7', override_params={'num_classes': 1})
163
+ state = torch.load(PRETRAINED_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
164
+ state.pop('_fc.weight')
165
+ state.pop('_fc.bias')
166
+ res = model.load_state_dict(state, strict=False)
167
+ assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
168
+
169
+ for module in model.modules():
170
+ if isinstance(module, MBConvBlock):
171
+ if module._block_args.expand_ratio != 1:
172
+ expand_conv = module._expand_conv
173
+ seq_expand_conv = SeqExpandConv(expand_conv.in_channels, expand_conv.out_channels, len(TRAIN_INDICES))
174
+ seq_expand_conv.conv.weight.data[:, :, 0, :, :].copy_(expand_conv.weight.data / 3)
175
+ seq_expand_conv.conv.weight.data[:, :, 1, :, :].copy_(expand_conv.weight.data / 3)
176
+ seq_expand_conv.conv.weight.data[:, :, 2, :, :].copy_(expand_conv.weight.data / 3)
177
+ module._expand_conv = seq_expand_conv
178
+
179
+ model = model.cuda()
180
+
181
+ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
182
+ _, rand_augment, _ = transforms_imagenet_train((CROP_HEIGHT, CROP_WIDTH), auto_augment='original-mstd0.5',
183
+ separate=True)
184
+
185
+ train_dataset = TrackPairDataset(os.path.join(config['ARTIFACTS_PATH'], TRACKS_ROOT),
186
+ os.path.join(config['ARTIFACTS_PATH'], TRACK_PAIRS_FILE_NAME),
187
+ TRAIN_INDICES,
188
+ track_length=TRACK_LENGTH,
189
+ track_transform=TrackTransform(FPS_RANGE, SCALE_RANGE, CRF_RANGE, TUNE_VALUES),
190
+ image_transform=Compose([
191
+ SmallestMaxSize(MIN_SIZE),
192
+ HorizontalFlip(),
193
+ RandomCrop(CROP_HEIGHT, CROP_WIDTH),
194
+ VisionTransform(rand_augment, p=0.5),
195
+ normalize,
196
+ ToTensor()
197
+ ]), sequence_mode=True)
198
+
199
+ print('Train dataset size: {}.'.format(len(train_dataset)))
200
+
201
+ warmup_optimizer = torch.optim.SGD(model._fc.parameters(), INITIAL_LR, momentum=MOMENTUM,
202
+ weight_decay=WEIGHT_DECAY, nesterov=True)
203
+
204
+ full_optimizer = torch.optim.SGD(model.parameters(), INITIAL_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY,
205
+ nesterov=True)
206
+ full_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(full_optimizer,
207
+ lambda iteration: (MAX_ITERS - iteration) / MAX_ITERS)
208
+
209
+ snapshots_root = os.path.join(config['ARTIFACTS_PATH'], SNAPSHOTS_ROOT, OUTPUT_FOLDER_NAME)
210
+ os.makedirs(snapshots_root)
211
+ log_root = os.path.join(config['ARTIFACTS_PATH'], LOGS_ROOT, OUTPUT_FOLDER_NAME)
212
+ os.makedirs(log_root)
213
+
214
+ writer = SummaryWriter(log_root)
215
+
216
+ iteration = 0
217
+ if iteration < NUM_WARMUP_ITERATIONS:
218
+ print('Start {} warmup iterations'.format(NUM_WARMUP_ITERATIONS))
219
+ model.eval()
220
+ model._fc.train()
221
+ for param in model.parameters():
222
+ param.requires_grad = False
223
+ for param in model._fc.parameters():
224
+ param.requires_grad = True
225
+ optimizer = warmup_optimizer
226
+ else:
227
+ print('Start without warmup iterations')
228
+ model.train()
229
+ optimizer = full_optimizer
230
+
231
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
232
+ writer.add_scalar('train/max_lr', max_lr, iteration)
233
+
234
+ epoch = 0
235
+ fake_prob_dist = distributions.beta.Beta(0.5, 0.5)
236
+ while True:
237
+ epoch += 1
238
+ print('Epoch {} is in progress'.format(epoch))
239
+ loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
240
+ for samples in tqdm.tqdm(loader):
241
+ iteration += 1
242
+ fake_input_tensor = torch.stack(samples['fake']).transpose(0, 1).cuda()
243
+ real_input_tensor = torch.stack(samples['real']).transpose(0, 1).cuda()
244
+ target_fake_prob = fake_prob_dist.sample((len(fake_input_tensor),)).float().cuda()
245
+ fake_weight = target_fake_prob.view(-1, 1, 1, 1, 1)
246
+
247
+ input_tensor = (1.0 - fake_weight) * real_input_tensor + fake_weight * fake_input_tensor
248
+ pred = model(input_tensor.flatten(0, 1)).flatten()
249
+
250
+ loss = F.binary_cross_entropy_with_logits(pred, target_fake_prob.repeat_interleave(len(TRAIN_INDICES)))
251
+
252
+ optimizer.zero_grad()
253
+ loss.backward()
254
+ optimizer.step()
255
+ if iteration > NUM_WARMUP_ITERATIONS:
256
+ full_lr_scheduler.step()
257
+ max_lr = max(param_group["lr"] for param_group in full_optimizer.param_groups)
258
+ writer.add_scalar('train/max_lr', max_lr, iteration)
259
+
260
+ writer.add_scalar('train/loss', loss.item(), iteration)
261
+
262
+ if iteration == NUM_WARMUP_ITERATIONS:
263
+ print('Stop warmup iterations')
264
+ model.train()
265
+ for param in model.parameters():
266
+ param.requires_grad = True
267
+ optimizer = full_optimizer
268
+
269
+ if iteration % SNAPSHOT_FREQUENCY == 0:
270
+ snapshot_name = SNAPSHOT_NAME_TEMPLATE.format(iteration)
271
+ snapshot_path = os.path.join(snapshots_root, snapshot_name)
272
+ print('Saving snapshot to {}'.format(snapshot_path))
273
+ torch.save(model.state_dict(), snapshot_path)
274
+
275
+ if iteration >= MAX_ITERS:
276
+ print('Stop training: maximum iteration count exceeded')
277
+ return
278
+
279
+
280
+ if __name__ == '__main__':
281
+ main()
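The sequence model above inflates each MBConv expansion conv (a 1x1 Conv2d) into a Conv3d with temporal kernel 3, copying the pretrained 2D kernel into every temporal slice divided by 3, so the initial 3D layer behaves roughly like a temporal average of the original 2D responses (up to zero padding at the sequence edges). A minimal sketch of that inflation on a single layer (the helper name inflate_1x1_conv and the dummy channel sizes are illustrative):

    import torch
    from torch import nn

    def inflate_1x1_conv(conv2d, temporal_kernel=3):
        conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels,
                           kernel_size=(temporal_kernel, 1, 1),
                           padding=(temporal_kernel // 2, 0, 0), bias=False)
        with torch.no_grad():
            for t in range(temporal_kernel):
                # conv2d.weight has shape (out, in, 1, 1), matching each temporal slice.
                conv3d.weight[:, :, t, :, :].copy_(conv2d.weight / temporal_kernel)
        return conv3d

    conv2d = nn.Conv2d(8, 32, kernel_size=1, bias=False)
    conv3d = inflate_1x1_conv(conv2d)
    print(conv3d.weight.shape)  # torch.Size([32, 8, 3, 1, 1])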