yuxi-liu-wired committed
Commit 0decf42
1 Parent(s): 6ad02bb
Files changed (48)
  1. LICENSE +21 -0
  2. README.md +5 -0
  3. RepVGG-main/.gitignore +2 -0
  4. RepVGG-main/LICENSE +21 -0
  5. RepVGG-main/README.md +334 -0
  6. RepVGG-main/arch.PNG +0 -0
  7. RepVGG-main/data/__init__.py +1 -0
  8. RepVGG-main/data/build.py +184 -0
  9. RepVGG-main/data/cached_image_folder.py +244 -0
  10. RepVGG-main/data/lmdb_dataset.py +164 -0
  11. RepVGG-main/data/samplers.py +21 -0
  12. RepVGG-main/data/zipreader.py +96 -0
  13. RepVGG-main/example_pspnet.py +161 -0
  14. RepVGG-main/jizhi_submit_train_repvgg.py +34 -0
  15. RepVGG-main/main.py +414 -0
  16. RepVGG-main/quantization/quant_qat_train.py +426 -0
  17. RepVGG-main/quantization/repvgg_quantized.py +63 -0
  18. RepVGG-main/repvgg.py +303 -0
  19. RepVGG-main/repvggplus.py +293 -0
  20. RepVGG-main/repvggplus_custom_L2.py +268 -0
  21. RepVGG-main/se_block.py +22 -0
  22. RepVGG-main/speed_acc.PNG +0 -0
  23. RepVGG-main/table.PNG +0 -0
  24. RepVGG-main/tools/convert.py +46 -0
  25. RepVGG-main/tools/insert_bn.py +217 -0
  26. RepVGG-main/tools/verify.py +30 -0
  27. RepVGG-main/train/config.py +217 -0
  28. RepVGG-main/train/cutout.py +55 -0
  29. RepVGG-main/train/logger.py +41 -0
  30. RepVGG-main/train/lr_scheduler.py +101 -0
  31. RepVGG-main/train/optimizer.py +71 -0
  32. RepVGG-main/train/randaug.py +407 -0
  33. RepVGG-main/utils.py +249 -0
  34. models/RepVGG-A0-train.pth +3 -0
  35. models/RepVGG-A1-train.pth +3 -0
  36. models/RepVGG-A2-train.pth +3 -0
  37. models/RepVGG-B0-train.pth +3 -0
  38. models/RepVGG-B1-train.pth +3 -0
  39. models/RepVGG-B1g2-train.pth +3 -0
  40. models/RepVGG-B1g4-train.pth +3 -0
  41. models/RepVGG-B2-train.pth +3 -0
  42. models/RepVGG-B2g4-200epochs-train.pth +3 -0
  43. models/RepVGG-B2g4-train.pth +3 -0
  44. models/RepVGG-B3-200epochs-train.pth +3 -0
  45. models/RepVGG-B3g4-200epochs-train.pth +3 -0
  46. models/RepVGG-D2se-200epochs-train.pth +3 -0
  47. models/RepVGGplus-L2pse-train-custom-wd-acc84.16.pth +3 -0
  48. models/RepVGGplus-L2pse-train256-acc84.06.pth +3 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 DingXiaoH
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,5 @@
+ ## RepVGG (PyTorch)
+
+ Copied from <https://github.com/DingXiaoH/RepVGG>. The original repo is in the `RepVGG-main` folder.
+
+ The models are in the `models` folder.
RepVGG-main/.gitignore ADDED
@@ -0,0 +1,2 @@
+ .idea/
+ *nori*
RepVGG-main/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 DingXiaoH
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
RepVGG-main/README.md ADDED
@@ -0,0 +1,334 @@
+ # RepVGG: Making VGG-style ConvNets Great Again (CVPR-2021) (PyTorch)
+
+ ## Highlights (Sep. 1st, 2022)
+
+ RepVGG and the methodology of re-parameterization have been used in **YOLOv6** ([paper](https://arxiv.org/abs/2209.02976), [code](https://github.com/meituan/YOLOv6)) and **YOLOv7** ([paper](https://arxiv.org/abs/2207.02696), [code](https://github.com/WongKinYiu/yolov7)).
+
+ I have re-organized this repository and released the RepVGGplus-L2pse model with 84.06% ImageNet accuracy. More RepVGGplus models will be released this month.
+
+ ## Introduction
+
+ RepVGG is a super simple VGG-like ConvNet architecture that achieves over **84% top-1 accuracy on ImageNet**! This repo contains the **pretrained models**, the code for building and training the models, the conversion from the training-time model to the inference-time structure, and **an example of using RepVGG for semantic segmentation**.
+
+ [The MegEngine version](https://github.com/megvii-model/RepVGG)
+
+ [TensorRT implementation with C++ API by @upczww](https://github.com/upczww/TensorRT-RepVGG). Great work!
+
+ [Another PyTorch implementation by @zjykzj](https://github.com/ZJCV/ZCls). He also presented detailed benchmarks [here](https://zcls.readthedocs.io/en/latest/benchmark-repvgg/). Nice work!
+
+ Included in a famous PyTorch model zoo: https://github.com/rwightman/pytorch-image-models.
+
+ [Objax implementation and models by @benjaminjellis](https://github.com/benjaminjellis/Objax-RepVGG). Great work!
+
+ Included in the [MegEngine Basecls model zoo](https://github.com/megvii-research/basecls/tree/main/zoo/public/repvgg).
+
+ Citation:
+
+     @inproceedings{ding2021repvgg,
+         title={Repvgg: Making vgg-style convnets great again},
+         author={Ding, Xiaohan and Zhang, Xiangyu and Ma, Ningning and Han, Jungong and Ding, Guiguang and Sun, Jian},
+         booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+         pages={13733--13742},
+         year={2021}
+     }
+
+
+ ## From RepVGG to RepVGGplus
+
+ We have released an improved architecture named RepVGGplus on top of the original version presented in the CVPR-2021 paper.
+
+ 1. RepVGGplus is deeper.
+
+ 2. RepVGGplus has auxiliary classifiers during training, which can be removed for inference.
+
+ 3. (Optional) RepVGGplus uses Squeeze-and-Excitation blocks to further improve the performance.
+
+ RepVGGplus outperformed several recent visual transformers with a top-1 accuracy of **84.06%** and higher throughput. Our training script is based on the [codebase of Swin Transformer](https://github.com/microsoft/Swin-Transformer/). The throughput is tested with the Swin codebase as well. We would like to thank the authors of [Swin](https://arxiv.org/abs/2103.14030) for their clean and well-structured code.
+
+ | Model | Train image size | Test size | ImageNet top-1 | Throughput (examples/second, test size 320, batch size 128, 2080Ti) |
+ | ------------- |:-------------:| -----:| -----:| -----:|
+ | RepVGGplus-L2pse | 256 | 320 | **84.06%** | **147** |
+ | Swin Transformer | 320 | 320 | 84.0% | 102 |
+
+ ("pse" means Squeeze-and-Excitation blocks after ReLU.)
+
+ Download this model: [Google Drive](https://drive.google.com/file/d/1x8VNLpfuLzg0xXDVIZv9yIIgqnSMoK-W/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/19YwKCTSPVgJu5Ueg0Q78-w?pwd=rvgg).
+
+ To train or finetune it, slightly change your training code like this:
+ ```
+ # Build model and data loader as usual
+ for samples, targets in train_data_loader:
+     # ......
+     outputs = model(samples)   # Your original code
+     if type(outputs) is dict:
+         # A training-time RepVGGplus outputs a dict. The items are:
+         #   'main': the output of the final layer
+         #   '*aux*': the outputs of the auxiliary classifiers
+         loss = 0
+         for name, pred in outputs.items():
+             if 'aux' in name:
+                 loss += 0.1 * criterion(pred, targets)   # Assume "criterion" is cross-entropy for classification
+             else:
+                 loss += criterion(pred, targets)
+     else:
+         loss = criterion(outputs, targets)   # Your original code
+     # Backward as usual
+     # ......
+ ```
+
+ To use it for downstream tasks like semantic segmentation, just discard the aux classifiers and the final FC layer.
+
+ Please note that the custom weight-decay trick I described last year turned out to be insignificant in our recent experiments (84.16% ImageNet accuracy and negligible improvements on other tasks), so I decided to stop using it as a new feature of RepVGGplus. You may optionally try it on your task. Please refer to the last part of this page for details.
+
+
+ ## Use our pretrained model
+
+ You may download _all_ of the ImageNet-pretrained models reported in the paper from Google Drive (https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq?usp=sharing) or Baidu Cloud (https://pan.baidu.com/s/1nCsZlMynnJwbUBKn0ch7dQ, the access code is "rvgg"). For ease of transfer learning on other tasks, they are all training-time models (with identity and 1x1 branches). You may test the accuracy by running
+ ```
+ python -m torch.distributed.launch --nproc_per_node 1 --master_port 12349 main.py --arch [model name] --data-path [/path/to/imagenet] --batch-size 32 --tag test --eval --resume [/path/to/weights/file] --opts DATA.DATASET imagenet DATA.IMG_SIZE [224 or 320]
+ ```
+ The valid model names include
+ ```
+ RepVGGplus-L2pse, RepVGG-A0, RepVGG-A1, RepVGG-A2, RepVGG-B0, RepVGG-B1, RepVGG-B1g2, RepVGG-B1g4, RepVGG-B2, RepVGG-B2g2, RepVGG-B2g4, RepVGG-B3, RepVGG-B3g2, RepVGG-B3g4
+ ```
+
+ ## Convert a training-time RepVGG into the inference-time structure
+
+ For a RepVGG model or a model with RepVGG as one of its components (e.g., the backbone), you can convert the whole model by simply calling **switch_to_deploy** on every RepVGG block. This is the recommended way. Examples are shown in ```tools/convert.py``` and ```example_pspnet.py```.
+ ```
+ for module in model.modules():
+     if hasattr(module, 'switch_to_deploy'):
+         module.switch_to_deploy()
+ ```
+ We have also released a script for the conversion. For example,
+ ```
+ python convert.py RepVGGplus-L2pse-train256-acc84.06.pth RepVGGplus-L2pse-deploy.pth -a RepVGGplus-L2pse
+ ```
+ Then you may build the inference-time model with ```--deploy```, load the converted weights, and test:
+ ```
+ python -m torch.distributed.launch --nproc_per_node 1 --master_port 12349 main.py --arch RepVGGplus-L2pse --data-path [/path/to/imagenet] --batch-size 32 --tag test --eval --resume RepVGGplus-L2pse-deploy.pth --deploy --opts DATA.DATASET imagenet DATA.IMG_SIZE [224 or 320]
+ ```
+
+ Besides the final conversion after training, you may also want to get the equivalent kernel and bias **during training** in a **differentiable** way at any time (```get_equivalent_kernel_bias``` in ```repvgg.py```). This may help training-based pruning or quantization. The underlying math is sketched below.
+
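+ For reference, here is a sketch of the math behind the conversion (our notation, following the paper). Each branch's BN is first fused into its conv: a branch with kernel $W$ and BN parameters $(\mu, \sigma, \gamma, \beta)$ becomes a conv with
+
+ $$W' = \frac{\gamma}{\sigma}\,W, \qquad b' = \beta - \frac{\gamma \mu}{\sigma}.$$
+
+ The three branches are then summed into a single 3x3 conv, with the 1x1 kernel (and the identity branch, viewed as a 1x1 conv with an identity kernel) zero-padded to 3x3:
+
+ $$K_{eq} = W'_{3\times 3} + \mathrm{pad}(W'_{1\times 1}) + \mathrm{pad}(W'_{id}), \qquad b_{eq} = b'_{3\times 3} + b'_{1\times 1} + b'_{id}.$$
+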
+ ## Train from scratch
+
+ ### Reproduce RepVGGplus-L2pse (not presented in the paper)
+
+ To train the recently released RepVGGplus-L2pse from scratch, activate mixup and use ```--AUG.PRESET raug15``` for RandAug.
+ ```
+ python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main.py --arch RepVGGplus-L2pse --data-path [/path/to/imagenet] --batch-size 32 --tag train_from_scratch --output-dir /path/to/save/the/log/and/checkpoints --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 AUG.MIXUP 0.2 DATA.DATASET imagenet DATA.IMG_SIZE 256 DATA.TEST_SIZE 320
+ ```
+
+ ### Reproduce the original RepVGG results reported in the paper
+
+ To reproduce the models reported in the CVPR-2021 paper, use neither mixup nor RandAug.
+ ```
+ python -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main.py --arch [model name] --data-path [/path/to/imagenet] --batch-size 32 --tag train_from_scratch --output-dir /path/to/save/the/log/and/checkpoints --opts TRAIN.EPOCHS 300 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 1e-4 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET weak AUG.MIXUP 0.0 DATA.DATASET imagenet DATA.IMG_SIZE 224
+ ```
+ The original RepVGG models were trained for 120 epochs with cosine learning rate decay from 0.1 to 0. We used 8 GPUs, a global batch size of 256, and a weight decay of 1e-4 (no weight decay on fc.bias, bn.bias, rbr_dense.bn.weight and rbr_1x1.bn.weight; weight decay on rbr_identity.weight makes little difference, and it is better to use it in most cases), with the same simple data preprocessing as the PyTorch official example:
+ ```
+ trans = transforms.Compose([
+     transforms.RandomResizedCrop(224),
+     transforms.RandomHorizontalFlip(),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+ ```
+
+
+ ## Other released models not presented in the paper
+
+ ***Apr 25, 2021*** A deeper RepVGG model achieves **83.55\% top-1 accuracy on ImageNet** with [SE](https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.html) blocks and an input resolution of 320x320 (and a wider version achieves **83.67\% accuracy** _without SE_). Note that it is trained at 224x224 but tested at 320x320, so it is still trainable with a global batch size of 256 on a single machine with 8 1080Ti GPUs. If you test it at 224x224, the top-1 accuracy is 81.82%. It has 1, 8, 14, 24, 1 layers in its 5 stages, respectively. The width multipliers are a=2.5 and b=5 (the same as RepVGG-B2). The model name is "RepVGG-D2se". The code for building the model (repvgg.py) and testing at 320x320 (the testing example above) has been updated, and the weights have been released at Google Drive and Baidu Cloud. Please check the links above.
+
+
+ ## Example 1: use Structural Re-parameterization like this in your own code
+ ```
+ from repvgg import repvgg_model_convert, create_RepVGG_A0
+ train_model = create_RepVGG_A0(deploy=False)
+ train_model.load_state_dict(torch.load('RepVGG-A0-train.pth'))   # or train from scratch
+ # do whatever you want with train_model
+ deploy_model = repvgg_model_convert(train_model, save_path='RepVGG-A0-deploy.pth')
+ # do whatever you want with deploy_model
+ ```
+ or
+ ```
+ deploy_model = create_RepVGG_A0(deploy=True)
+ deploy_model.load_state_dict(torch.load('RepVGG-A0-deploy.pth'))
+ # do whatever you want with deploy_model
+ ```
+ If you use RepVGG as a component of another model, the conversion is as simple as calling **switch_to_deploy** on every RepVGG block.
+
+
+ ## Example 2: use RepVGG as the backbone for downstream tasks
+
+ I would suggest you use popular frameworks like MMDetection and MMSegmentation. The features from any stage or layer of RepVGG can be fed into the task-specific heads. If you are not familiar with such frameworks and just would like to see a simple example, please check ```example_pspnet.py```, which shows how to use RepVGG as the backbone of PSPNet for semantic segmentation: 1) build a PSPNet with a RepVGG backbone, 2) load the ImageNet-pretrained weights, 3) convert the whole model with **switch_to_deploy**, 4) save and use the converted model for inference.
+
+
+
+ ## Quantization
+
+ RepVGG works fine with FP16, but the accuracy may decrease when the model is directly quantized to INT8. If INT8 quantization is essential to your application, we suggest three practical solutions.
+
+ ### Solution A: RepOptimizer
+
+ I strongly recommend trying RepOptimizer if quantization is essential to your application. RepOptimizer directly trains a VGG-like model via Gradient Re-parameterization without any structural conversions. Quantizing a VGG-like model trained with RepOptimizer is as easy as quantizing a regular model. RepOptimizer has already been used in YOLOv6.
+
+ Paper: https://arxiv.org/abs/2205.15242
+
+ Code: https://github.com/DingXiaoH/RepOptimizers
+
+ Tutorial provided by the authors of YOLOv6: https://github.com/meituan/YOLOv6/blob/main/docs/tutorial_repopt.md. Great work! Many thanks!
+
+ ### Solution B: custom quantization-aware training
+
+ Another choice is to constrain the equivalent kernel (get_equivalent_kernel_bias() in repvgg.py) to be low-bit (e.g., make every param fall in {-127, -126, ..., 126, 127} for int8), instead of constraining the params of every kernel separately as for an ordinary model. A minimal sketch of this idea follows.
+
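+ A minimal sketch, assuming the RepVGG block exposes ```get_equivalent_kernel_bias``` and the ```rbr_dense```/```nonlinearity``` attributes as in this repo's ```repvgg.py```; the symmetric grid, scale choice, and straight-through estimator below are our illustrative choices, and the optional SE block is ignored:
+ ```
+ import torch
+ import torch.nn.functional as F
+
+ def fake_quantize_int8(t):
+     # Symmetric per-tensor fake quantization: snap values onto the grid
+     # {-127, ..., 127} * scale, but keep them in float so training continues.
+     scale = t.detach().abs().max().clamp(min=1e-8) / 127.0
+     q = torch.round(t / scale).clamp(-127, 127) * scale
+     # Straight-through estimator: forward uses q, backward sees identity.
+     return t + (q - t).detach()
+
+ def quantized_block_forward(block, x):
+     # Constrain the *equivalent* kernel of the whole block, instead of
+     # quantizing the 3x3, 1x1 and identity branches separately.
+     kernel, bias = block.get_equivalent_kernel_bias()   # differentiable, from repvgg.py
+     kernel = fake_quantize_int8(kernel)
+     conv = block.rbr_dense.conv   # borrow stride/groups from the 3x3 branch
+     out = F.conv2d(x, kernel, bias, stride=conv.stride, padding=1, groups=conv.groups)
+     return block.nonlinearity(out)
+ ```
+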
+ ### Solution C: use the off-the-shelf toolboxes
+
+ (TODO: check and refactor the code of this example)
+
+ For simplicity, we can also use off-the-shelf quantization toolboxes to quantize RepVGG. We use the simple QAT (quantization-aware training) tool in torch.quantization as an example.
+
+ 1. Start from the base model converted into the inference-time structure. We insert BN after the converted 3x3 conv layers because QAT with torch.quantization requires BN. Specifically, we run the model on the ImageNet training set to record the mean/std statistics, use them to initialize the BN layers, and initialize BN.gamma/beta accordingly so that the saved model has the same outputs as the inference-time model (a sketch of this initialization follows the commands below).
+
+ ```
+ python quantization/convert.py RepVGG-A0.pth RepVGG-A0_base.pth -a RepVGG-A0
+ python quantization/insert_bn.py [imagenet-folder] RepVGG-A0_base.pth RepVGG-A0_withBN.pth -a RepVGG-A0 -b 32 -n 40000
+ ```
+
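+ A minimal sketch of this identity-preserving initialization, assuming we have recorded the per-channel mean ```m``` and std ```s``` of a converted conv layer's output (the names are ours; the repo's actual implementation is in ```quantization/insert_bn.py```):
+ ```
+ import torch
+ from torch import nn
+
+ def identity_bn(m, s, eps=1e-5):
+     # Build a BatchNorm2d that is initially a no-op on activations whose
+     # per-channel statistics are (m, s): with gamma = s, beta = m and the
+     # recorded running stats, BN(x) = s * (x - m) / s + m = x.
+     bn = nn.BatchNorm2d(num_features=m.numel(), eps=eps)
+     with torch.no_grad():
+         bn.running_mean.copy_(m)
+         bn.running_var.copy_(s ** 2 - eps)   # so that sqrt(var + eps) = s
+         bn.weight.copy_(s)                   # gamma
+         bn.bias.copy_(m)                     # beta
+     return bn
+ ```
+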
+ 2. Build the model, prepare it for QAT (torch.quantization.prepare_qat), and conduct QAT. This is only an example and the hyper-parameters may not be optimal.
+ ```
+ python quantization/quant_qat_train.py [imagenet-folder] -j 32 --epochs 20 -b 256 --lr 1e-3 --weight-decay 4e-5 --base-weights RepVGG-A0_withBN.pth --tag quanttest
+ ```
+
+
+ ## FAQs
+
+ **Q**: Is the inference-time model's output the _same_ as the training-time model?
+
+ **A**: Yes. You can verify that by running
+ ```
+ python tools/verify.py
+ ```
+
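+ For a quick in-Python check, here is a sketch using ```create_RepVGG_A0``` and ```repvgg_model_convert``` from ```repvgg.py``` (remember ```eval()``` so that BN uses its running statistics):
+ ```
+ import torch
+ from repvgg import create_RepVGG_A0, repvgg_model_convert
+
+ train_model = create_RepVGG_A0(deploy=False).eval()   # eval() so BN uses running stats
+ x = torch.randn(1, 3, 224, 224)
+ y_train = train_model(x)
+ deploy_model = repvgg_model_convert(train_model)
+ y_deploy = deploy_model(x)
+ print(((y_train - y_deploy) ** 2).sum())              # should be ~0 up to numerical noise
+ ```
+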
+ **Q**: How do I use the pretrained RepVGG models for other tasks?
+
+ **A**: It is better to finetune the training-time RepVGG models on your datasets, then do the conversion after finetuning and before you deploy the models. For example, say you want to use PSPNet for semantic segmentation: build a PSPNet with a training-time RepVGG model as the backbone, load pre-trained weights into the backbone, and finetune the PSPNet on your segmentation dataset. Then convert the backbone following the code provided in this repo and keep the other task-specific structures (the PSPNet parts, in this case). The pseudo code will look like
+ ```
+ # train_backbone = create_RepVGG_B2(deploy=False)
+ # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
+ # train_pspnet = build_pspnet(backbone=train_backbone)
+ # segmentation_train(train_pspnet)
+ # deploy_pspnet = repvgg_model_convert(train_pspnet)
+ # segmentation_test(deploy_pspnet)
+ ```
+ There is an example in **example_pspnet.py**.
+
+ Finetuning with a converted RepVGG also makes sense if you insert a BN after each conv (please see the quantization example), but the performance may be slightly lower.
+
+ **Q**: I tried to finetune your model with multiple GPUs but got an error. Why are the names of params like "stage1.0.rbr_dense.conv.weight" in the downloaded weight file but sometimes like "module.stage1.0.rbr_dense.conv.weight" (shown by nn.Module.named_parameters()) in my model?
+
+ **A**: DistributedDataParallel may prefix "module." to the names of params and cause a mismatch when loading weights by name. The simplest solution is to load the weights (model.load_state_dict(...)) before DistributedDataParallel(model). Otherwise, you may insert "module." before the names like this
+ ```
+ checkpoint = torch.load(...)    # This is just a name-value dict
+ ckpt = {('module.' + k): v for k, v in checkpoint.items()}
+ model.load_state_dict(ckpt)
+ ```
+ Likewise, if the param names in the checkpoint file start with "module." but those in your model do not, you may strip the prefix like line 50 in test.py:
+ ```
+ ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}   # strip the names
+ model.load_state_dict(ckpt)
+ ```
+ **Q**: So a RepVGG model derives the equivalent 3x3 kernels before each forwarding to save computations?
+
+ **A**: No! More precisely, we do the conversion only once right after training. Then the training-time model can be discarded, and the resultant model has only 3x3 kernels. We only save and use the resultant model.
+
+
+ ## An optional trick with a custom weight decay (deprecated)
+
+ This trick is deprecated. Please check ```repvggplus_custom_L2.py```. The intuition is to add regularization on the equivalent kernel rather than on each branch separately. It may work in some cases; the regularization term is written out below.
+
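+ Assuming outputs['L2'] is the squared L2 norm of the equivalent kernel $K_{eq}$ (our notation; see the snippet below), the term added to the loss is
+
+ $$\mathcal{L}_{reg} = \frac{\lambda}{2}\,\lVert K_{eq} \rVert_2^2,$$
+
+ where $\lambda$ is the WEIGHT_DECAY constant, replacing the ordinary weight decay on the individual branches.
+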
+ The trained model can be downloaded at [Google Drive](https://drive.google.com/file/d/14I1jWU4rS4y0wdxm03SnEVP1Tx6GGfKu/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1qFGmgJ6Ir6W3wAcCBQb9-w?pwd=rvgg).
+
+ The training code should be changed like this:
+ ```
+ # Build model and data loader as usual
+ for samples, targets in train_data_loader:
+     # ......
+     outputs = model(samples)   # Your original code
+     if type(outputs) is dict:
+         # A training-time RepVGGplus outputs a dict. The items are:
+         #   'main': the output of the final layer
+         #   '*aux*': the outputs of the auxiliary classifiers
+         #   'L2': the custom L2 regularization term
+         loss = WEIGHT_DECAY * 0.5 * outputs['L2']
+         for name, pred in outputs.items():
+             if name == 'L2':
+                 pass
+             elif 'aux' in name:
+                 loss += 0.1 * criterion(pred, targets)   # Assume "criterion" is cross-entropy for classification
+             else:
+                 loss += criterion(pred, targets)
+     else:
+         loss = criterion(outputs, targets)   # Your original code
+     # Backward as usual
+     # ......
+ ```
+
+
+
+ ## Contact
+
+ **xiaohding@gmail.com** (The original Tsinghua mailbox dxh17@mails.tsinghua.edu.cn will expire in several months.)
+
+ Google Scholar profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en
+
+ Homepage: https://dingxiaohan.xyz/
+
+ My open-sourced papers and repos:
+
+ The **Structural Re-parameterization Universe**:
+
+ 1. RepLKNet (CVPR 2022) **Powerful efficient architecture with very large kernels (31x31) and guidelines for using large kernels in modern CNNs**\
+ [Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs](https://arxiv.org/abs/2203.06717)\
+ [code](https://github.com/DingXiaoH/RepLKNet-pytorch).
+
+ 2. **RepOptimizer** (ICLR 2023) uses **Gradient Re-parameterization** to train powerful models efficiently. The training-time **RepOpt-VGG** is **as simple as the inference-time** model. It also addresses the problem of quantization.\
+ [Re-parameterizing Your Optimizers rather than Architectures](https://arxiv.org/pdf/2205.15242.pdf)\
+ [code](https://github.com/DingXiaoH/RepOptimizers).
+
+ 3. RepVGG (CVPR 2021) **A super simple and powerful VGG-style ConvNet architecture**. Up to **84.16%** ImageNet top-1 accuracy!\
+ [RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697)\
+ [code](https://github.com/DingXiaoH/RepVGG).
+
+ 4. RepMLP (CVPR 2022) **MLP-style building block and architecture**\
+ [RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality](https://arxiv.org/abs/2112.11081)\
+ [code](https://github.com/DingXiaoH/RepMLP).
+
+ 5. ResRep (ICCV 2021) **State-of-the-art** channel pruning (Res50, 55\% FLOPs reduction, 76.15\% acc)\
+ [ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting](https://openaccess.thecvf.com/content/ICCV2021/papers/Ding_ResRep_Lossless_CNN_Pruning_via_Decoupling_Remembering_and_Forgetting_ICCV_2021_paper.pdf)\
+ [code](https://github.com/DingXiaoH/ResRep).
+
+ 6. ACB (ICCV 2019) is a CNN component without any inference-time costs. The first work of our Structural Re-parameterization Universe.\
+ [ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks](http://openaccess.thecvf.com/content_ICCV_2019/papers/Ding_ACNet_Strengthening_the_Kernel_Skeletons_for_Powerful_CNN_via_Asymmetric_ICCV_2019_paper.pdf)\
+ [code](https://github.com/DingXiaoH/ACNet).
+
+ 7. DBB (CVPR 2021) is a CNN component with higher performance than ACB and still no inference-time costs. Sometimes I call it ACNet v2 because "DBB" is 2 bits larger than "ACB" in ASCII (lol).\
+ [Diverse Branch Block: Building a Convolution as an Inception-like Unit](https://arxiv.org/abs/2103.13425)\
+ [code](https://github.com/DingXiaoH/DiverseBranchBlock).
+
+ **Model compression and acceleration**:
+
+ 1. (CVPR 2019) Channel pruning: [Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure](http://openaccess.thecvf.com/content_CVPR_2019/html/Ding_Centripetal_SGD_for_Pruning_Very_Deep_Convolutional_Networks_With_Complicated_CVPR_2019_paper.html)\
+ [code](https://github.com/DingXiaoH/Centripetal-SGD)
+
+ 2. (ICML 2019) Channel pruning: [Approximated Oracle Filter Pruning for Destructive CNN Width Optimization](http://proceedings.mlr.press/v97/ding19a.html)\
+ [code](https://github.com/DingXiaoH/AOFP)
+
+ 3. (NeurIPS 2019) Unstructured pruning: [Global Sparse Momentum SGD for Pruning Very Deep Neural Networks](http://papers.nips.cc/paper/8867-global-sparse-momentum-sgd-for-pruning-very-deep-neural-networks.pdf)\
+ [code](https://github.com/DingXiaoH/GSM-SGD)
RepVGG-main/arch.PNG ADDED
RepVGG-main/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .build import build_loader
RepVGG-main/data/build.py ADDED
@@ -0,0 +1,184 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+ import torch
+ import numpy as np
+ import torch.distributed as dist
+ from torchvision import datasets, transforms
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import Mixup
+ from timm.data import create_transform
+ try:
+     from timm.data.transforms import str_to_pil_interp as _pil_interp
+ except ImportError:
+     # older timm versions expose the helper under a private name
+     from timm.data.transforms import _pil_interp
+ from .cached_image_folder import CachedImageFolder
+ from .samplers import SubsetRandomSampler
+
+
+ def build_loader(config):
+     config.defrost()
+     dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+     config.freeze()
+     print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
+     dataset_val, _ = build_dataset(is_train=False, config=config)
+     print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
+
+     num_tasks = dist.get_world_size()
+     global_rank = dist.get_rank()
+     if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
+         # each rank caches (and samples) its own slice of the dataset
+         indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
+         sampler_train = SubsetRandomSampler(indices)
+     else:
+         sampler_train = torch.utils.data.DistributedSampler(
+             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+         )
+
+     if dataset_val is None:
+         sampler_val = None
+     else:
+         indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())  # TODO
+         sampler_val = SubsetRandomSampler(indices)
+
+     data_loader_train = torch.utils.data.DataLoader(
+         dataset_train, sampler=sampler_train,
+         batch_size=config.DATA.BATCH_SIZE,
+         num_workers=config.DATA.NUM_WORKERS,
+         pin_memory=config.DATA.PIN_MEMORY,
+         drop_last=True,
+     )
+
+     if dataset_val is None:
+         data_loader_val = None
+     else:
+         data_loader_val = torch.utils.data.DataLoader(
+             dataset_val, sampler=sampler_val,
+             batch_size=config.DATA.TEST_BATCH_SIZE,
+             shuffle=False,
+             num_workers=config.DATA.NUM_WORKERS,
+             pin_memory=config.DATA.PIN_MEMORY,
+             drop_last=False
+         )
+
+     # setup mixup / cutmix
+     mixup_fn = None
+     mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+     if mixup_active:
+         mixup_fn = Mixup(
+             mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+             prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+             label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+     return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
+ def build_dataset(is_train, config):
+     if config.DATA.DATASET == 'imagenet':
+         transform = build_transform(is_train, config)
+         prefix = 'train' if is_train else 'val'
+         if config.DATA.ZIP_MODE:
+             ann_file = prefix + "_map.txt"
+             prefix = prefix + ".zip@/"
+             dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
+                                         cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
+         else:
+             import torchvision
+             print('use raw ImageNet data')
+             dataset = torchvision.datasets.ImageNet(root=config.DATA.DATA_PATH, split='train' if is_train else 'val', transform=transform)
+         nb_classes = 1000
+
+     elif config.DATA.DATASET == 'cf100':
+         mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
+         std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
+         if is_train:
+             transform = transforms.Compose([
+                 transforms.RandomCrop(32, padding=4),
+                 transforms.RandomHorizontalFlip(),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean, std)
+             ])
+             dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=True, download=True, transform=transform)
+         else:
+             transform = transforms.Compose(
+                 [transforms.ToTensor(),
+                  transforms.Normalize(mean, std)])
+             dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=False, download=True, transform=transform)
+         nb_classes = 100
+
+     else:
+         raise NotImplementedError("We only support ImageNet and CIFAR-100 now.")
+
+     return dataset, nb_classes
+
+
+ def build_transform(is_train, config):
+     resize_im = config.DATA.IMG_SIZE > 32
+     if is_train:
+         # this should always dispatch to transforms_imagenet_train
+         if config.AUG.PRESET is None:
+             transform = create_transform(
+                 input_size=config.DATA.IMG_SIZE,
+                 is_training=True,
+                 color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
+                 auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+                 re_prob=config.AUG.REPROB,
+                 re_mode=config.AUG.REMODE,
+                 re_count=config.AUG.RECOUNT,
+                 interpolation=config.DATA.INTERPOLATION,
+             )
+             print('=============================== original AUG! ', config.AUG.AUTO_AUGMENT)
+             if not resize_im:
+                 # replace RandomResizedCropAndInterpolation with RandomCrop
+                 transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+
+         elif config.AUG.PRESET.strip() == 'raug15':
+             from train.randaug import RandAugPolicy
+             transform = transforms.Compose([
+                 transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
+                 transforms.RandomHorizontalFlip(),
+                 RandAugPolicy(magnitude=15),
+                 transforms.ToTensor(),
+                 transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+             ])
+             print('---------------------- RAND AUG 15 distortion!')
+
+         elif config.AUG.PRESET.strip() == 'weak':
+             transform = transforms.Compose([
+                 transforms.RandomResizedCrop(config.DATA.IMG_SIZE),
+                 transforms.RandomHorizontalFlip(),
+                 transforms.ToTensor(),
+                 transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+             ])
+         elif config.AUG.PRESET.strip() == 'none':
+             transform = transforms.Compose([
+                 transforms.Resize(config.DATA.IMG_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+                 transforms.CenterCrop(config.DATA.IMG_SIZE),
+                 transforms.ToTensor(),
+                 transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+             ])
+         else:
+             raise ValueError('Unknown augmentation preset: ' + config.AUG.PRESET)
+         print(transform)
+         return transform
+
+     t = []
+     if resize_im:
+         if config.TEST.CROP:
+             # resize to 256/224 of the test size to maintain the same ratio w.r.t. 224 images
+             size = int((256 / 224) * config.DATA.TEST_SIZE)
+             t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)))
+             t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
+         else:
+             # default for testing
+             t.append(transforms.Resize(config.DATA.TEST_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)))
+             t.append(transforms.CenterCrop(config.DATA.TEST_SIZE))
+     t.append(transforms.ToTensor())
+     t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+     trans = transforms.Compose(t)
+     return trans
RepVGG-main/data/cached_image_folder.py ADDED
@@ -0,0 +1,244 @@
+ import io
+ import os
+ import time
+ import torch.distributed as dist
+ import torch.utils.data as data
+ from PIL import Image
+
+ from .zipreader import is_zip_path, ZipReader
+
+
+ def has_file_allowed_extension(filename, extensions):
+     """Checks if a file has an allowed extension.
+     Args:
+         filename (string): path to a file
+     Returns:
+         bool: True if the filename ends with a known image extension
+     """
+     filename_lower = filename.lower()
+     return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+ def find_classes(dir):
+     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+     classes.sort()
+     class_to_idx = {classes[i]: i for i in range(len(classes))}
+     return classes, class_to_idx
+
+
+ def make_dataset(dir, class_to_idx, extensions):
+     images = []
+     dir = os.path.expanduser(dir)
+     for target in sorted(os.listdir(dir)):
+         d = os.path.join(dir, target)
+         if not os.path.isdir(d):
+             continue
+
+         for root, _, fnames in sorted(os.walk(d)):
+             for fname in sorted(fnames):
+                 if has_file_allowed_extension(fname, extensions):
+                     path = os.path.join(root, fname)
+                     item = (path, class_to_idx[target])
+                     images.append(item)
+
+     return images
+
+
+ def make_dataset_with_ann(ann_file, img_prefix, extensions):
+     images = []
+     with open(ann_file, "r") as f:
+         contents = f.readlines()
+         for line_str in contents:
+             path_contents = [c for c in line_str.split('\t')]
+             im_file_name = path_contents[0]
+             class_index = int(path_contents[1])
+
+             assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
+             item = (os.path.join(img_prefix, im_file_name), class_index)
+
+             images.append(item)
+
+     return images
+
+
+ class DatasetFolder(data.Dataset):
+     """A generic data loader where the samples are arranged in this way: ::
+         root/class_x/xxx.ext
+         root/class_x/xxy.ext
+         root/class_x/xxz.ext
+         root/class_y/123.ext
+         root/class_y/nsdf3.ext
+         root/class_y/asd932_.ext
+     Args:
+         root (string): Root directory path.
+         loader (callable): A function to load a sample given its path.
+         extensions (list[string]): A list of allowed extensions.
+         transform (callable, optional): A function/transform that takes in
+             a sample and returns a transformed version.
+             E.g, ``transforms.RandomCrop`` for images.
+         target_transform (callable, optional): A function/transform that takes
+             in the target and transforms it.
+     Attributes:
+         samples (list): List of (sample path, class_index) tuples
+     """
+
+     def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
+                  cache_mode="no"):
+         # image folder mode
+         if ann_file == '':
+             _, class_to_idx = find_classes(root)
+             samples = make_dataset(root, class_to_idx, extensions)
+         # zip mode
+         else:
+             samples = make_dataset_with_ann(os.path.join(root, ann_file),
+                                             os.path.join(root, img_prefix),
+                                             extensions)
+
+         if len(samples) == 0:
+             raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
+                                 "Supported extensions are: " + ",".join(extensions)))
+
+         self.root = root
+         self.loader = loader
+         self.extensions = extensions
+
+         self.samples = samples
+         self.labels = [y_1k for _, y_1k in samples]
+         self.classes = list(set(self.labels))
+
+         self.transform = transform
+         self.target_transform = target_transform
+
+         self.cache_mode = cache_mode
+         if self.cache_mode != "no":
+             self.init_cache()
+
+     def init_cache(self):
+         assert self.cache_mode in ["part", "full"]
+         n_sample = len(self.samples)
+         global_rank = dist.get_rank()
+         world_size = dist.get_world_size()
+
+         samples_bytes = [None for _ in range(n_sample)]
+         start_time = time.time()
+         for index in range(n_sample):
+             if index % (n_sample // 10) == 0:
+                 t = time.time() - start_time
+                 print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
+                 start_time = time.time()
+             path, target = self.samples[index]
+             if self.cache_mode == "full":
+                 samples_bytes[index] = (ZipReader.read(path), target)
+             elif self.cache_mode == "part" and index % world_size == global_rank:
+                 samples_bytes[index] = (ZipReader.read(path), target)
+             else:
+                 samples_bytes[index] = (path, target)
+         self.samples = samples_bytes
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+         Returns:
+             tuple: (sample, target) where target is class_index of the target class.
+         """
+         path, target = self.samples[index]
+         sample = self.loader(path)
+         if self.transform is not None:
+             sample = self.transform(sample)
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+
+         return sample, target
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __repr__(self):
+         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+         fmt_str += '    Root Location: {}\n'.format(self.root)
+         tmp = '    Transforms (if any): '
+         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         tmp = '    Target Transforms (if any): '
+         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         return fmt_str
+
+
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+ def pil_loader(path):
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     if isinstance(path, bytes):
+         img = Image.open(io.BytesIO(path))
+     elif is_zip_path(path):
+         data = ZipReader.read(path)
+         img = Image.open(io.BytesIO(data))
+     else:
+         with open(path, 'rb') as f:
+             img = Image.open(f)
+     return img.convert('RGB')
+
+
+ def accimage_loader(path):
+     import accimage
+     try:
+         return accimage.Image(path)
+     except IOError:
+         # Potentially a decoding problem, fall back to PIL.Image
+         return pil_loader(path)
+
+
+ def default_img_loader(path):
+     from torchvision import get_image_backend
+     if get_image_backend() == 'accimage':
+         return accimage_loader(path)
+     else:
+         return pil_loader(path)
+
+
+ class CachedImageFolder(DatasetFolder):
+     """A generic data loader where the images are arranged in this way: ::
+         root/dog/xxx.png
+         root/dog/xxy.png
+         root/dog/xxz.png
+         root/cat/123.png
+         root/cat/nsdf3.png
+         root/cat/asd932_.png
+     Args:
+         root (string): Root directory path.
+         transform (callable, optional): A function/transform that takes in an PIL image
+             and returns a transformed version. E.g, ``transforms.RandomCrop``
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         loader (callable, optional): A function to load an image given its path.
+     Attributes:
+         imgs (list): List of (image path, class_index) tuples
+     """
+
+     def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
+                  loader=default_img_loader, cache_mode="no"):
+         super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+                                                 ann_file=ann_file, img_prefix=img_prefix,
+                                                 transform=transform, target_transform=target_transform,
+                                                 cache_mode=cache_mode)
+         self.imgs = self.samples
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+         Returns:
+             tuple: (image, target) where target is class_index of the target class.
+         """
+         path, target = self.samples[index]
+         image = self.loader(path)
+         if self.transform is not None:
+             img = self.transform(image)
+         else:
+             img = image
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+
+         return img, target
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ from PIL import Image
4
+ import six
5
+ import lmdb
6
+ import warnings
7
+ warnings.simplefilter(action='ignore', category=FutureWarning)
8
+ import pyarrow as pa
9
+ import numpy as np
10
+ import torch.utils.data as data
11
+ from torch.utils.data import DataLoader
12
+ from torchvision.datasets import ImageFolder
13
+
14
+ train_lmdb_path = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train.lmdb'
15
+ val_lmdb_path = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/val.lmdb'
16
+
17
+ # from data.lmdb_dataset import ImageFolderLMDB, train_lmdb_path, val_lmdb_path
18
+ # lmdb_path = train_lmdb_path if is_train else val_lmdb_path
19
+ # dataset = ImageFolderLMDB(db_path=lmdb_path, transform=transform)
20
+
21
+ def loads_pyarrow(buf):
22
+ """
23
+ Args:
24
+ buf: the output of `dumps`.
25
+ """
26
+ return pa.deserialize(buf)
27
+
28
+
29
+ class ImageFolderLMDB(data.Dataset):
30
+ def __init__(self, db_path, transform=None, target_transform=None):
31
+ self.db_path = db_path
32
+ self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
33
+ readonly=True, lock=False,
34
+ readahead=False, meminit=False)
35
+ with self.env.begin(write=False) as txn:
36
+ self.length = loads_pyarrow(txn.get(b'__len__'))
37
+ self.keys = loads_pyarrow(txn.get(b'__keys__'))
38
+
39
+ self.transform = transform
40
+ self.target_transform = target_transform
41
+
42
+ def __getstate__(self):
43
+ state = self.__dict__
44
+ state["env"] = None
45
+ return state
46
+
47
+ def __setstate__(self, state):
48
+ self.__dict__ = state
49
+ self.env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
50
+ readonly=True, lock=False,
51
+ readahead=False, meminit=False)
52
+ with self.env.begin(write=False) as txn:
53
+ self.length = loads_pyarrow(txn.get(b'__len__'))
54
+ self.keys = loads_pyarrow(txn.get(b'__keys__'))
55
+
56
+ def __getitem__(self, index):
57
+ env = self.env
58
+ with env.begin(write=False) as txn:
59
+ byteflow = txn.get(self.keys[index])
60
+
61
+ unpacked = loads_pyarrow(byteflow)
62
+
63
+ # load img
64
+ imgbuf = unpacked[0]
65
+ buf = six.BytesIO()
66
+ buf.write(imgbuf)
67
+ buf.seek(0)
68
+ img = Image.open(buf).convert('RGB')
69
+ if self.transform is not None:
70
+ img = self.transform(img)
71
+
72
+ # load label
73
+ target = unpacked[1]
74
+ if self.target_transform is not None:
75
+ target = self.transform(target)
76
+
77
+ return img, target
78
+ # if self.transform is not None:
79
+ # img = self.transform(img)
80
+ #
81
+ # # im2arr = np.array(img)
82
+ #
83
+ # if self.target_transform is not None:
84
+ # target = self.target_transform(target)
85
+ #
86
+ # return img, target
87
+ # return im2arr, target
88
+
89
+ def __len__(self):
90
+ return self.length
91
+
92
+ def __repr__(self):
93
+ return self.__class__.__name__ + ' (' + self.db_path + ')'
94
+
95
+
96
+ def raw_reader(path):
97
+ with open(path, 'rb') as f:
98
+ bin_data = f.read()
99
+ return bin_data
100
+
101
+
102
+ def dumps_pyarrow(obj):
103
+ """
104
+ Serialize an object.
105
+ Returns:
106
+ Implementation-dependent bytes-like object
107
+ """
108
+ return pa.serialize(obj).to_buffer()
109
+
110
+
111
+ def folder2lmdb(dpath, name="train", write_frequency=5000):
112
+ directory = osp.expanduser(osp.join(dpath, name))
113
+ print("Loading dataset from %s" % directory)
114
+ dataset = ImageFolder(directory, loader=raw_reader)
115
+ data_loader = DataLoader(dataset, num_workers=4, collate_fn=lambda x: x)
116
+
117
+ lmdb_path = osp.join(dpath, "%s.lmdb" % name)
118
+ isdir = os.path.isdir(lmdb_path)
119
+
120
+ print("Generate LMDB to %s" % lmdb_path)
121
+ db = lmdb.open(lmdb_path, subdir=isdir,
122
+ map_size=1099511627776 * 2, readonly=False,
123
+ meminit=False, map_async=True)
124
+
125
+ txn = db.begin(write=True)
126
+ for idx, data in enumerate(data_loader):
127
+ image, label = data[0]
128
+
129
+ txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
130
+ if idx % write_frequency == 0:
131
+ print("[%d/%d]" % (idx, len(data_loader)))
132
+ txn.commit()
133
+ txn = db.begin(write=True)
134
+
135
+ # finish iterating through dataset
136
+ txn.commit()
137
+ keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
138
+ with db.begin(write=True) as txn:
139
+ txn.put(b'__keys__', dumps_pyarrow(keys))
140
+ txn.put(b'__len__', dumps_pyarrow(len(keys)))
141
+
142
+ print("Flushing database ...")
143
+ db.sync()
144
+ db.close()
145
+
146
+
147
+
148
+
149
+ if __name__ == "__main__":
150
+ # lmdb_path = '/apdcephfs/share_1016399/0_public_datasets/imageNet_2012/train.lmdb'
151
+ # from lmdb_dataset import ImageFolderLMDB
152
+ # dataset = ImageFolderLMDB(db_path=lmdb_path)
153
+ # for x, y in dataset:
154
+ # print(type(x), type(y))
155
+ # exit()
156
+
157
+ import argparse
158
+ parser = argparse.ArgumentParser()
159
+ parser.add_argument('--dir', type=str, required=True, help="The dataset directory to process")
160
+ args = parser.parse_args()
161
+ # generate lmdb
162
+ path = args.dir
163
+ folder2lmdb(path, name="train")
164
+ folder2lmdb(path, name="val")
RepVGG-main/data/samplers.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+
+ class SubsetRandomSampler(torch.utils.data.Sampler):
+     r"""Samples elements randomly from a given list of indices, without replacement.
+
+     Arguments:
+         indices (sequence): a sequence of indices
+     """
+
+     def __init__(self, indices):
+         self.epoch = 0
+         self.indices = indices
+
+     def __iter__(self):
+         return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+     def __len__(self):
+         return len(self.indices)
+
+     def set_epoch(self, epoch):
+         # kept for API compatibility with DistributedSampler; the epoch is not
+         # used for seeding here, so each iteration draws a fresh permutation
+         self.epoch = epoch
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ from PIL import ImageFile
7
+
8
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
9
+
10
+
11
+ def is_zip_path(img_or_path):
12
+ """judge if this is a zip path"""
13
+ return '.zip@' in img_or_path
14
+
15
+
16
+ class ZipReader(object):
17
+ """A class to read zipped files"""
18
+ zip_bank = dict()
19
+
20
+ def __init__(self):
21
+ super(ZipReader, self).__init__()
22
+
23
+ @staticmethod
24
+ def get_zipfile(path):
25
+ zip_bank = ZipReader.zip_bank
26
+ if path not in zip_bank:
27
+ zfile = zipfile.ZipFile(path, 'r')
28
+ zip_bank[path] = zfile
29
+ return zip_bank[path]
30
+
31
+ @staticmethod
32
+ def split_zip_style_path(path):
33
+ pos_at = path.index('@')
34
+ assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
35
+
36
+ zip_path = path[0: pos_at]
37
+ folder_path = path[pos_at + 1:]
38
+ folder_path = str.strip(folder_path, '/')
39
+ return zip_path, folder_path
40
+
41
+ @staticmethod
42
+ def list_folder(path):
43
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
44
+
45
+ zfile = ZipReader.get_zipfile(zip_path)
46
+ folder_list = []
47
+ for file_foler_name in zfile.namelist():
48
+ file_foler_name = str.strip(file_foler_name, '/')
49
+ if file_foler_name.startswith(folder_path) and \
50
+ len(os.path.splitext(file_foler_name)[-1]) == 0 and \
51
+ file_foler_name != folder_path:
52
+ if len(folder_path) == 0:
53
+ folder_list.append(file_foler_name)
54
+ else:
55
+ folder_list.append(file_foler_name[len(folder_path) + 1:])
56
+
57
+ return folder_list
58
+
59
+ @staticmethod
60
+ def list_files(path, extension=None):
61
+ if extension is None:
62
+ extension = ['.*']
63
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
64
+
65
+ zfile = ZipReader.get_zipfile(zip_path)
66
+ file_lists = []
67
+ for file_foler_name in zfile.namelist():
68
+ file_foler_name = str.strip(file_foler_name, '/')
69
+ if file_foler_name.startswith(folder_path) and \
70
+ str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
71
+ if len(folder_path) == 0:
72
+ file_lists.append(file_foler_name)
73
+ else:
74
+ file_lists.append(file_foler_name[len(folder_path) + 1:])
75
+
76
+ return file_lists
77
+
78
+ @staticmethod
79
+ def read(path):
80
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
81
+ zfile = ZipReader.get_zipfile(zip_path)
82
+ data = zfile.read(path_img)
83
+ return data
84
+
85
+ @staticmethod
86
+ def imread(path):
87
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
88
+ zfile = ZipReader.get_zipfile(zip_path)
89
+ data = zfile.read(path_img)
90
+ try:
91
+ im = Image.open(io.BytesIO(data))
92
+ except:
93
+ print("ERROR IMG LOADED: ", path_img)
94
+ random_img = np.random.rand(224, 224, 3) * 255
95
+ im = Image.fromarray(np.uint8(random_img))
96
+ return im
RepVGG-main/example_pspnet.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from repvgg import get_RepVGG_func_by_name
+ 
+ # The PSPNet parts are from
+ # https://github.com/hszhao/semseg
+ 
+ class PPM(nn.Module):
+     def __init__(self, in_dim, reduction_dim, bins, BatchNorm):
+         super(PPM, self).__init__()
+         self.features = []
+         for bin in bins:
+             self.features.append(nn.Sequential(
+                 nn.AdaptiveAvgPool2d(bin),
+                 nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
+                 BatchNorm(reduction_dim),
+                 nn.ReLU(inplace=True)
+             ))
+         self.features = nn.ModuleList(self.features)
+ 
+     def forward(self, x):
+         x_size = x.size()
+         out = [x]
+         for f in self.features:
+             out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
+         return torch.cat(out, 1)
+ 
+ 
+ class PSPNet(nn.Module):
+     def __init__(self,
+                  backbone_name, backbone_file, deploy,
+                  bins=(1, 2, 3, 6), dropout=0.1, classes=2,
+                  zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d,
+                  pretrained=True):
+         super(PSPNet, self).__init__()
+         assert 2048 % len(bins) == 0
+         assert classes > 1
+         assert zoom_factor in [1, 2, 4, 8]
+         self.zoom_factor = zoom_factor
+         self.use_ppm = use_ppm
+         self.criterion = criterion
+ 
+         repvgg_fn = get_RepVGG_func_by_name(backbone_name)
+         backbone = repvgg_fn(deploy)
+         if pretrained:
+             checkpoint = torch.load(backbone_file)
+             if 'state_dict' in checkpoint:
+                 checkpoint = checkpoint['state_dict']
+             ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}  # strip the 'module.' prefix left by DataParallel
+             backbone.load_state_dict(ckpt)
+ 
+         self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = backbone.stage0, backbone.stage1, backbone.stage2, backbone.stage3, backbone.stage4
+ 
+         # The last two stages should have stride=1 for semantic segmentation.
+         # Note that the stride of the 1x1 branch must match that of the 3x3 branch.
+         # Use dilation following the implementation of PSPNet.
+         secondlast_channel = 0
+         for n, m in self.layer3.named_modules():
+             if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d):
+                 m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
+                 print('change dilation, padding, stride of ', n)
+                 secondlast_channel = m.out_channels
+             elif 'rbr_1x1' in n and isinstance(m, nn.Conv2d):
+                 m.stride = (1, 1)
+                 print('change stride of ', n)
+         last_channel = 0
+         for n, m in self.layer4.named_modules():
+             if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d):
+                 m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
+                 print('change dilation, padding, stride of ', n)
+                 last_channel = m.out_channels
+             elif 'rbr_1x1' in n and isinstance(m, nn.Conv2d):
+                 m.stride = (1, 1)
+                 print('change stride of ', n)
+ 
+         fea_dim = last_channel
+         aux_in = secondlast_channel
+ 
+         if use_ppm:
+             self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins, BatchNorm)
+             fea_dim *= 2  # PPM concatenates the input with len(bins) branches of fea_dim/len(bins) channels each
+ 
+         self.cls = nn.Sequential(
+             nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
+             BatchNorm(512),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(p=dropout),
+             nn.Conv2d(512, classes, kernel_size=1)
+         )
+         if self.training:
+             self.aux = nn.Sequential(
+                 nn.Conv2d(aux_in, 256, kernel_size=3, padding=1, bias=False),
+                 BatchNorm(256),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout2d(p=dropout),
+                 nn.Conv2d(256, classes, kernel_size=1)
+             )
+ 
+     def forward(self, x, y=None):
+         x_size = x.size()
+         assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
+         h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)  # e.g., 713 -> (713 - 1) / 8 * 8 + 1 = 713 with zoom_factor=8
+         w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)
+ 
+         x = self.layer0(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x_tmp = self.layer3(x)
+         x = self.layer4(x_tmp)
+ 
+         if self.use_ppm:
+             x = self.ppm(x)
+         x = self.cls(x)
+         if self.zoom_factor != 1:
+             x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
+ 
+         if self.training:
+             aux = self.aux(x_tmp)
+             if self.zoom_factor != 1:
+                 aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
+             main_loss = self.criterion(x, y)
+             aux_loss = self.criterion(aux, y)
+             return x.max(1)[1], main_loss, aux_loss
+         else:
+             return x
+ 
+ 
+ if __name__ == '__main__':
+     # 1. Build the PSPNet with a RepVGG backbone. Download the ImageNet-pretrained weight file and load it.
+     model = PSPNet(backbone_name='RepVGG-A0', backbone_file='RepVGG-A0-train.pth', deploy=False, classes=19, pretrained=True)
+ 
+     # 2. Train it
+     # seg_train(model)
+ 
+     # 3. Convert and check the equivalence
+     input = torch.rand(4, 3, 713, 713)
+     model.eval()
+     print(model)
+     y_train = model(input)
+     for module in model.modules():
+         if hasattr(module, 'switch_to_deploy'):
+             module.switch_to_deploy()
+     y_deploy = model(input)
+     print('output is ', y_deploy.size())
+     print('=================== The diff is')
+     print(((y_deploy - y_train) ** 2).sum())
+ 
+     # 4. Save the converted model
+     torch.save(model.state_dict(), 'PSPNet-RepVGG-A0-deploy.pth')
+     del model  # Or do whatever you want with it
+ 
+     # 5. For inference, load the saved model. There is no need to load the ImageNet-pretrained weights again.
+     deploy_model = PSPNet(backbone_name='RepVGG-A0', backbone_file=None, deploy=True, classes=19, pretrained=False)
+     deploy_model.eval()
+     deploy_model.load_state_dict(torch.load('PSPNet-RepVGG-A0-deploy.pth'))
+ 
+     # 6. Check again or do whatever you want
+     y_deploy = deploy_model(input)
+     print('=================== The diff is')
+     print(((y_deploy - y_train) ** 2).sum())
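
The conversion loop in step 3 is what `repvgg_model_convert` in repvgg.py (added later in this commit) packages into one call; an equivalent sketch for the whole PSPNet, assuming the same trained `model`:

    from repvgg import repvgg_model_convert

    # Calls switch_to_deploy() on every RepVGGBlock and saves the deploy-form weights;
    # do_copy=True keeps the training-time model intact.
    deploy_model = repvgg_model_convert(model, save_path='PSPNet-RepVGG-A0-deploy.pth', do_copy=True)
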
RepVGG-main/jizhi_submit_train_repvgg.py ADDED
@@ -0,0 +1,34 @@
+ import argparse
+ import os
+ import json
+ 
+ parser = argparse.ArgumentParser('JIZHI submit', add_help=False)
+ parser.add_argument('arch', default=None, type=str)
+ parser.add_argument('tag', default=None, type=str)
+ parser.add_argument('--config', default='/apdcephfs_cq2/share_1290939/xiaohanding/cnt/default_V100x8_elastic_config.json', type=str,
+                     help='config file')
+ 
+ 
+ args = parser.parse_args()
+ run_dir = f'{args.arch}_{args.tag}'
+ 
+ cmd = f'python3 -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main.py ' \
+       f'--arch {args.arch} --batch-size 32 --tag {args.tag} --output-dir /apdcephfs_cq2/share_1290939/xiaohanding/swin_exps/{args.arch}_{args.tag} --opts TRAIN.EPOCHS 120 TRAIN.BASE_LR 0.1 TRAIN.WEIGHT_DECAY 4e-5 TRAIN.WARMUP_EPOCHS 5 MODEL.LABEL_SMOOTHING 0.1 AUG.PRESET raug15 DATA.DATASET imagenet'
+ 
+ # Use os.chdir/os.makedirs instead of os.system('cd ...')/os.system('mkdir ...'):
+ # each os.system call runs in its own shell, so a 'cd' there would not affect this process.
+ os.chdir('/apdcephfs_cq2/share_1290939/xiaohanding/RepVGG/')
+ os.makedirs(f'runs/{run_dir}', exist_ok=True)
+ with open(f'runs/{run_dir}/start.sh', 'w') as f:
+     f.write(cmd)
+ with open(args.config, 'r') as f:
+     json_content = json.load(f)
+ json_content['model_local_file_path'] = f'/apdcephfs_cq2/share_1290939/xiaohanding/RepVGG/runs/{run_dir}'
+ config_file_path = f'/apdcephfs_cq2/share_1290939/xiaohanding/RepVGG/runs/{run_dir}/config.json'
+ with open(config_file_path, 'w') as f:
+     json.dump(json_content, f)
+ 
+ os.system(f'cp *.py runs/{run_dir}/')
+ os.system(f'cp -r data runs/{run_dir}/')
+ os.system(f'cp -r train runs/{run_dir}/')
+ os.chdir(f'runs/{run_dir}')
+ os.system(f'jizhi_client start -scfg {config_file_path}')
RepVGG-main/main.py ADDED
@@ -0,0 +1,414 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+ import os
+ import time
+ import argparse
+ import datetime
+ import copy
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+ from timm.utils import accuracy, AverageMeter
+ from train.config import get_config
+ from data import build_loader
+ from train.lr_scheduler import build_scheduler
+ from train.logger import create_logger
+ from utils import load_checkpoint, load_weights, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, save_latest, update_model_ema, unwrap_model
+ from train.optimizer import build_optimizer
+ from repvggplus import create_RepVGGplus_by_name
+ 
+ try:
+     # noinspection PyUnresolvedReferences
+     from apex import amp
+ except ImportError:
+     amp = None
+ 
+ 
+ def parse_option():
+     parser = argparse.ArgumentParser('RepVGG training script built on the codebase of Swin Transformer', add_help=False)
+     parser.add_argument(
+         "--opts",
+         help="Modify config options by adding 'KEY VALUE' pairs. ",
+         default=None,
+         nargs='+',
+     )
+ 
+     # easy config modification
+     parser.add_argument('--arch', default=None, type=str, help='arch name')
+     parser.add_argument('--batch-size', default=128, type=int, help="batch size for a single GPU")
+     parser.add_argument('--data-path', default='/your/path/to/dataset', type=str, help='path to dataset')
+     parser.add_argument('--scales-path', default=None, type=str, help='path to the trained Hyper-Search model')
+     parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+     parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+                         help='no: no cache, '
+                              'full: cache all data, '
+                              'part: shard the dataset into non-overlapping pieces and only cache one piece')
+     parser.add_argument('--resume', help='resume from checkpoint')
+     parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+     parser.add_argument('--use-checkpoint', action='store_true',
+                         help="whether to use gradient checkpointing to save memory")
+     parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'],  # Note: use amp if you have it
+                         help='mixed precision opt level; if O0, no amp is used')
+     parser.add_argument('--output', default='/your/path/to/save/dir', type=str, metavar='PATH',
+                         help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+     parser.add_argument('--tag', help='tag of experiment')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+     parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+ 
+     # distributed training
+     parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')
+ 
+     args, unparsed = parser.parse_known_args()
+ 
+     config = get_config(args)
+ 
+     return args, config
+ 
+ 
+ def main(config):
+     dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
+ 
+     logger.info(f"Creating model: {config.MODEL.ARCH}")
+ 
+     model = create_RepVGGplus_by_name(config.MODEL.ARCH, deploy=False, use_checkpoint=args.use_checkpoint)
+     optimizer = build_optimizer(config, model)
+ 
+     logger.info(str(model))
+     model.cuda()
+ 
+     if torch.cuda.device_count() > 1:
+         if config.AMP_OPT_LEVEL != "O0":
+             model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK],
+                                                           broadcast_buffers=False)
+         model_without_ddp = model.module
+     else:
+         if config.AMP_OPT_LEVEL != "O0":
+             model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL)
+         model_without_ddp = model
+ 
+     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     logger.info(f"number of params: {n_parameters}")
+     if hasattr(model_without_ddp, 'flops'):
+         flops = model_without_ddp.flops()
+         logger.info(f"number of GFLOPs: {flops / 1e9}")
+ 
+     if config.THROUGHPUT_MODE:
+         throughput(data_loader_val, model, logger)
+         return
+ 
+     if config.EVAL_MODE:
+         load_weights(model, config.MODEL.RESUME)
+         acc1, acc5, loss = validate(config, data_loader_val, model)
+         logger.info(f"Only eval. top-1 acc, top-5 acc, loss: {acc1:.3f}, {acc5:.3f}, {loss:.5f}")
+         return
+ 
+     lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+ 
+     if config.AUG.MIXUP > 0.:
+         # smoothing is handled with the mixup label transform
+         criterion = SoftTargetCrossEntropy()
+     elif config.MODEL.LABEL_SMOOTHING > 0.:
+         criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
+     else:
+         criterion = torch.nn.CrossEntropyLoss()
+ 
+     max_accuracy = 0.0
+     max_ema_accuracy = 0.0
+ 
+     if config.TRAIN.EMA_ALPHA > 0 and (not config.EVAL_MODE) and (not config.THROUGHPUT_MODE):
+         model_ema = copy.deepcopy(model)
+     else:
+         model_ema = None
+ 
+     if config.TRAIN.AUTO_RESUME:
+         resume_file = auto_resume_helper(config.OUTPUT)
+         if resume_file:
+             if config.MODEL.RESUME:
+                 logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+             config.defrost()
+             config.MODEL.RESUME = resume_file
+             config.freeze()
+             logger.info(f'auto resuming from {resume_file}')
+         else:
+             logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+ 
+     if (not config.THROUGHPUT_MODE) and config.MODEL.RESUME:
+         max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger, model_ema=model_ema)
+ 
+     logger.info("Start training")
+     start_time = time.time()
+     for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+         data_loader_train.sampler.set_epoch(epoch)
+ 
+         train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=model_ema)
+         if dist.get_rank() == 0:
+             save_latest(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, model_ema=model_ema)
+             if epoch % config.SAVE_FREQ == 0:
+                 save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, model_ema=model_ema)
+ 
+         if epoch % config.SAVE_FREQ == 0 or epoch >= (config.TRAIN.EPOCHS - 10):
+ 
+             if data_loader_val is not None:
+                 acc1, acc5, loss = validate(config, data_loader_val, model)
+                 logger.info(f"Accuracy of the network at epoch {epoch}: {acc1:.3f}%")
+                 max_accuracy = max(max_accuracy, acc1)
+                 logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+                 if max_accuracy == acc1 and dist.get_rank() == 0:
+                     save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger,
+                                     is_best=True, model_ema=model_ema)
+ 
+             if model_ema is not None:
+                 if data_loader_val is not None:
+                     acc1, acc5, loss = validate(config, data_loader_val, model_ema)
+                     logger.info(f"EMA accuracy of the network at epoch {epoch}: {acc1:.3f}%")
+                     max_ema_accuracy = max(max_ema_accuracy, acc1)
+                     logger.info(f'Max EMA accuracy: {max_ema_accuracy:.2f}%')
+                     if max_ema_accuracy == acc1 and dist.get_rank() == 0:
+                         best_ema_path = os.path.join(config.OUTPUT, 'best_ema.pth')
+                         logger.info(f"{best_ema_path} best EMA saving......")
+                         torch.save(unwrap_model(model_ema).state_dict(), best_ema_path)
+                 else:
+                     latest_ema_path = os.path.join(config.OUTPUT, 'latest_ema.pth')
+                     logger.info(f"{latest_ema_path} latest EMA saving......")
+                     torch.save(unwrap_model(model_ema).state_dict(), latest_ema_path)
+ 
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     logger.info('Training time {}'.format(total_time_str))
+ 
+ 
+ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=None):
+     model.train()
+     optimizer.zero_grad()
+ 
+     num_steps = len(data_loader)
+     batch_time = AverageMeter()
+     loss_meter = AverageMeter()
+     norm_meter = AverageMeter()
+ 
+     start = time.time()
+     end = time.time()
+     for idx, (samples, targets) in enumerate(data_loader):
+         samples = samples.cuda(non_blocking=True)
+         targets = targets.cuda(non_blocking=True)
+ 
+         if mixup_fn is not None:
+             samples, targets = mixup_fn(samples, targets)
+ 
+         outputs = model(samples)
+ 
+         if isinstance(outputs, dict):
+             # RepVGGplus returns a dict of predictions; auxiliary (deep-supervision) heads are down-weighted.
+             loss = 0.0
+             for name, pred in outputs.items():
+                 if 'aux' in name:
+                     loss += 0.1 * criterion(pred, targets)
+                 else:
+                     loss += criterion(pred, targets)
+         else:
+             loss = criterion(outputs, targets)
+ 
+         if config.TRAIN.ACCUMULATION_STEPS > 1:
+             loss = loss / config.TRAIN.ACCUMULATION_STEPS
+             if config.AMP_OPT_LEVEL != "O0":
+                 with amp.scale_loss(loss, optimizer) as scaled_loss:
+                     scaled_loss.backward()
+                 if config.TRAIN.CLIP_GRAD:
+                     grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
+                 else:
+                     grad_norm = get_grad_norm(amp.master_params(optimizer))
+             else:
+                 loss.backward()
+                 if config.TRAIN.CLIP_GRAD:
+                     grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+                 else:
+                     grad_norm = get_grad_norm(model.parameters())
+             if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+                 optimizer.step()
+                 optimizer.zero_grad()
+                 lr_scheduler.step_update(epoch * num_steps + idx)
+         else:
+             optimizer.zero_grad()
+             if config.AMP_OPT_LEVEL != "O0":
+                 with amp.scale_loss(loss, optimizer) as scaled_loss:
+                     scaled_loss.backward()
+                 if config.TRAIN.CLIP_GRAD:
+                     grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
+                 else:
+                     grad_norm = get_grad_norm(amp.master_params(optimizer))
+             else:
+                 loss.backward()
+                 if config.TRAIN.CLIP_GRAD:
+                     grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+                 else:
+                     grad_norm = get_grad_norm(model.parameters())
+             optimizer.step()
+             lr_scheduler.step_update(epoch * num_steps + idx)
+ 
+         torch.cuda.synchronize()
+ 
+         loss_meter.update(loss.item(), targets.size(0))
+         norm_meter.update(grad_norm)
+         batch_time.update(time.time() - end)
+ 
+         if model_ema is not None:
+             update_model_ema(config, dist.get_world_size(), model=model, model_ema=model_ema, cur_epoch=epoch, cur_iter=idx)
+ 
+         end = time.time()
+ 
+         if idx % config.PRINT_FREQ == 0:
+             lr = optimizer.param_groups[0]['lr']
+             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+             etas = batch_time.avg * (num_steps - idx)
+             logger.info(
+                 f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                 f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                 f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                 f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                 f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                 f'mem {memory_used:.0f}MB')
+     epoch_time = time.time() - start
+     logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+ 
+ 
+ @torch.no_grad()
+ def validate(config, data_loader, model):
+     criterion = torch.nn.CrossEntropyLoss()
+     model.eval()
+ 
+     batch_time = AverageMeter()
+     loss_meter = AverageMeter()
+     acc1_meter = AverageMeter()
+     acc5_meter = AverageMeter()
+ 
+     end = time.time()
+     for idx, (images, target) in enumerate(data_loader):
+         images = images.cuda(non_blocking=True)
+         target = target.cuda(non_blocking=True)
+ 
+         # compute output
+         output = model(images)
+ 
+         # =============================== deep-supervision part: evaluate with the main head only
+         if isinstance(output, dict):
+             output = output['main']
+ 
+         # measure accuracy and record loss
+         loss = criterion(output, target)
+         acc1, acc5 = accuracy(output, target, topk=(1, 5))
+ 
+         acc1 = reduce_tensor(acc1)
+         acc5 = reduce_tensor(acc5)
+         loss = reduce_tensor(loss)
+ 
+         loss_meter.update(loss.item(), target.size(0))
+         acc1_meter.update(acc1.item(), target.size(0))
+         acc5_meter.update(acc5.item(), target.size(0))
+ 
+         # measure elapsed time
+         batch_time.update(time.time() - end)
+         end = time.time()
+ 
+         if idx % config.PRINT_FREQ == 0:
+             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+             logger.info(
+                 f'Test: [{idx}/{len(data_loader)}]\t'
+                 f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                 f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                 f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                 f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                 f'Mem {memory_used:.0f}MB')
+     logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+     return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+ 
+ 
+ @torch.no_grad()
+ def throughput(data_loader, model, logger):
+     model.eval()
+ 
+     for idx, (images, _) in enumerate(data_loader):
+         images = images.cuda(non_blocking=True)
+ 
+         batch_size = images.shape[0]
+         for i in range(50):
+             model(images)  # warm up
+         torch.cuda.synchronize()
+         logger.info("throughput averaged over 30 runs")
+         tic1 = time.time()
+         for i in range(30):
+             model(images)
+         torch.cuda.synchronize()
+         tic2 = time.time()
+         throughput = 30 * batch_size / (tic2 - tic1)
+         logger.info(f"batch_size {batch_size} throughput {throughput}")
+         return
+ 
+ 
+ if __name__ == '__main__':
+     args, config = parse_option()
+ 
+     if config.AMP_OPT_LEVEL != "O0":
+         assert amp is not None, "amp not installed!"
+ 
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         rank = int(os.environ["RANK"])
+         world_size = int(os.environ['WORLD_SIZE'])
+     else:
+         rank = -1
+         world_size = -1
+     torch.cuda.set_device(config.LOCAL_RANK)
+     torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+     torch.distributed.barrier()
+     seed = config.SEED + dist.get_rank()
+ 
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     cudnn.benchmark = True
+ 
+     if not config.EVAL_MODE:
+         # Linearly scale the learning rate with the total batch size; this may not be optimal.
+         # e.g., with --batch-size 32 on 8 GPUs: BASE_LR * 32 * 8 / 256 = BASE_LR
+         linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0
+         linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0
+         linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0
+         # gradient accumulation also needs to scale the learning rate
+         if config.TRAIN.ACCUMULATION_STEPS > 1:
+             linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+             linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+             linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+         config.defrost()
+         config.TRAIN.BASE_LR = linear_scaled_lr
+         config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+         config.TRAIN.MIN_LR = linear_scaled_min_lr
+         config.freeze()
+ 
+         print('==========================================')
+         print('real base lr: ', config.TRAIN.BASE_LR)
+         print('==========================================')
+ 
+     os.makedirs(config.OUTPUT, exist_ok=True)
+ 
+     logger = create_logger(output_dir=config.OUTPUT, dist_rank=0 if torch.cuda.device_count() == 1 else dist.get_rank(), name=f"{config.MODEL.ARCH}")
+ 
+     if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
+         path = os.path.join(config.OUTPUT, "config.json")
+         with open(path, "w") as f:
+             f.write(config.dump())
+         logger.info(f"Full config saved to {path}")
+ 
+     # print config
+     logger.info(config.dump())
+ 
+     main(config)
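
RepVGGplus models return a dict of predictions, which is why train_one_epoch above branches on the output type. A minimal self-contained sketch of that deep-supervision loss (the head names here are illustrative; only the 'aux' substring matters):

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    targets = torch.randint(0, 1000, (4,))
    outputs = {'main': torch.randn(4, 1000), 'aux_stage2': torch.randn(4, 1000)}

    # Auxiliary heads are down-weighted by 0.1, exactly as in train_one_epoch.
    loss = sum((0.1 if 'aux' in name else 1.0) * criterion(pred, targets)
               for name, pred in outputs.items())
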
RepVGG-main/quantization/quant_qat_train.py ADDED
@@ -0,0 +1,426 @@
+ import argparse
+ import os
+ import random
+ import shutil
+ import time
+ import warnings
+ import torch.nn as nn
+ import torch.nn.parallel
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ import torch.optim
+ import torch.multiprocessing as mp
+ import torch.utils.data
+ import torch.utils.data.distributed
+ from utils import *  # AverageMeter, ProgressMeter, accuracy, log_msg, load_checkpoint, the default ImageNet loaders, WarmupCosineAnnealingLR, etc.
+ import torchvision.transforms as transforms
+ import PIL
+ 
+ best_acc1 = 0
+ 
+ IMAGENET_TRAINSET_SIZE = 1281167
+ 
+ parser = argparse.ArgumentParser(description='PyTorch Whole Model Quant')
+ parser.add_argument('data', metavar='DIR',
+                     help='path to dataset')
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')
+ parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
+                     help='number of data loading workers (default: 8)')
+ parser.add_argument('--epochs', default=8, type=int, metavar='N',
+                     help='number of epochs for each run')
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                     help='manual epoch number (useful on restarts)')
+ parser.add_argument('-b', '--batch-size', default=256, type=int,
+                     metavar='N',
+                     help='mini-batch size (default: 256), this is the total '
+                          'batch size of all GPUs on the current node when '
+                          'using Data Parallel or Distributed Data Parallel')
+ parser.add_argument('--val-batch-size', default=100, type=int, metavar='V',
+                     help='validation batch size')
+ parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
+                     metavar='LR', help='learning rate for finetuning', dest='lr')
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                     help='momentum')
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                     metavar='W', help='weight decay (default: 1e-4)',
+                     dest='weight_decay')
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
+                     metavar='N', help='print frequency (default: 10)')
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                     help='path to latest checkpoint (default: none)')
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                     help='evaluate model on validation set')
+ parser.add_argument('--world-size', default=-1, type=int,
+                     help='number of nodes for distributed training')
+ parser.add_argument('--rank', default=-1, type=int,
+                     help='node rank for distributed training')
+ parser.add_argument('--dist-url', default='tcp://127.0.0.1:23333', type=str,
+                     help='url used to set up distributed training')
+ parser.add_argument('--dist-backend', default='nccl', type=str,
+                     help='distributed backend')
+ parser.add_argument('--seed', default=None, type=int,
+                     help='seed for initializing training. ')
+ parser.add_argument('--gpu', default=None, type=int,
+                     help='GPU id to use.')
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
+                     help='Use multi-processing distributed training to launch '
+                          'N processes per node, which has N GPUs. This is the '
+                          'fastest way to use PyTorch for either single node or '
+                          'multi node data parallel training')
+ parser.add_argument('--base-weights', default=None, type=str,
+                     help='weights of the base model.')
+ parser.add_argument('--tag', default='testtest', type=str,
+                     help='the tag for identifying the log and model files. Just a string.')
+ parser.add_argument('--fpfinetune', dest='fpfinetune', action='store_true',
+                     help='full-precision finetune (no quantization)')
+ parser.add_argument('--fixobserver', dest='fixobserver', action='store_true',
+                     help='freeze the quantization observers late in training?')
+ parser.add_argument('--fixbn', dest='fixbn', action='store_true',
+                     help='freeze the BN statistics late in training?')
+ parser.add_argument('--quantlayers', default='all', type=str, choices=['all', 'exclude_first_and_linear', 'exclude_first_and_last'],
+                     help='which layers to quantize')
+ 
+ 
+ def sgd_optimizer(model, lr, momentum, weight_decay):
+     params = []
+     for key, value in model.named_parameters():
+         if not value.requires_grad:
+             continue
+         apply_weight_decay = weight_decay
+         apply_lr = lr
+         if value.ndimension() < 2:  # no weight decay on BN parameters and biases
+             apply_weight_decay = 0
+             print('set weight decay=0 for {}'.format(key))
+         if 'bias' in key:
+             apply_lr = 2 * lr  # Just a Caffe-style common practice. Made no difference.
+         params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_weight_decay}]
+     optimizer = torch.optim.SGD(params, lr, momentum=momentum)
+     return optimizer
+ 
+ 
+ def main():
+     args = parser.parse_args()
+ 
+     if args.seed is not None:
+         random.seed(args.seed)
+         torch.manual_seed(args.seed)
+         cudnn.deterministic = True
+         warnings.warn('You have chosen to seed training. '
+                       'This will turn on the CUDNN deterministic setting, '
+                       'which can slow down your training considerably! '
+                       'You may see unexpected behavior when restarting '
+                       'from checkpoints.')
+ 
+     if args.gpu is not None:
+         warnings.warn('You have chosen a specific GPU. This will completely '
+                       'disable data parallelism.')
+ 
+     if args.dist_url == "env://" and args.world_size == -1:
+         args.world_size = int(os.environ["WORLD_SIZE"])
+ 
+     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
+ 
+     ngpus_per_node = torch.cuda.device_count()
+     if args.multiprocessing_distributed:
+         # Since we have ngpus_per_node processes per node, the total world_size
+         # needs to be adjusted accordingly
+         args.world_size = ngpus_per_node * args.world_size
+         # Use torch.multiprocessing.spawn to launch distributed processes: the
+         # main_worker process function
+         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+     else:
+         # Simply call the main_worker function
+         main_worker(args.gpu, ngpus_per_node, args)
+ 
+ 
+ def get_default_train_trans(args):
+     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+     if (not hasattr(args, 'resolution')) or args.resolution == 224:
+         trans = transforms.Compose([
+             transforms.RandomResizedCrop(224),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             normalize])
+     else:
+         raise ValueError('Not yet implemented.')
+     return trans
+ 
+ 
+ def get_default_val_trans(args):
+     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+     if (not hasattr(args, 'resolution')) or args.resolution == 224:
+         trans = transforms.Compose([
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             normalize])
+     else:
+         trans = transforms.Compose([
+             transforms.Resize(args.resolution, interpolation=PIL.Image.BILINEAR),
+             transforms.CenterCrop(args.resolution),
+             transforms.ToTensor(),
+             normalize,
+         ])
+     return trans
+ 
+ 
+ def main_worker(gpu, ngpus_per_node, args):
+     global best_acc1
+     args.gpu = gpu
+     log_file = 'quant_{}_exp.txt'.format(args.tag)
+ 
+     if args.gpu is not None:
+         print("Use GPU: {} for training".format(args.gpu))
+ 
+     if args.distributed:
+         if args.dist_url == "env://" and args.rank == -1:
+             args.rank = int(os.environ["RANK"])
+         if args.multiprocessing_distributed:
+             # For multiprocessing distributed training, rank needs to be the
+             # global rank among all the processes
+             args.rank = args.rank * ngpus_per_node + gpu
+         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                 world_size=args.world_size, rank=args.rank)
+ 
+     # 1. Build and load the base model
+     from repvgg import get_RepVGG_func_by_name
+     repvgg_build_func = get_RepVGG_func_by_name(args.arch)
+     base_model = repvgg_build_func(deploy=True)
+     from tools.insert_bn import directly_insert_bn_without_init
+     directly_insert_bn_without_init(base_model)
+     if args.base_weights is not None:
+         load_checkpoint(base_model, args.base_weights)
+ 
+     # 2. Wrap the model for QAT, or keep it as-is for a full-precision finetune
+     if not args.fpfinetune:
+         from quantization.repvgg_quantized import RepVGGWholeQuant
+         qat_model = RepVGGWholeQuant(repvgg_model=base_model, quantlayers=args.quantlayers)
+         qat_model.prepare_quant()
+     else:
+         qat_model = base_model
+         log_msg('===================== not QAT, just full-precision finetune ===========', log_file)
+ 
+     # ===================================================
+     # From now on, the code is very similar to ordinary training
+     # ===================================================
+ 
+     is_main = not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0)
+ 
+     if is_main:
+         for n, p in qat_model.named_parameters():
+             print(n, p.size())
+         for n, p in qat_model.named_buffers():
+             print(n, p.size())
+         log_msg('epochs {}, lr {}, weight_decay {}'.format(args.epochs, args.lr, args.weight_decay), log_file)
+         # You will see that the model now has quantization-related parameters (zero-points and scales)
+ 
+     if not torch.cuda.is_available():
+         print('using CPU, this will be slow')
+     elif args.distributed:
+         if args.gpu is not None:
+             torch.cuda.set_device(args.gpu)
+             qat_model.cuda(args.gpu)
+             args.batch_size = int(args.batch_size / ngpus_per_node)
+             args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+             qat_model = torch.nn.parallel.DistributedDataParallel(qat_model, device_ids=[args.gpu])
+         else:
+             qat_model.cuda()
+             qat_model = torch.nn.parallel.DistributedDataParallel(qat_model)
+     elif args.gpu is not None:
+         torch.cuda.set_device(args.gpu)
+         qat_model = qat_model.cuda(args.gpu)
+     else:
+         # DataParallel will divide and allocate batch_size to all available GPUs
+         qat_model = torch.nn.DataParallel(qat_model).cuda()
+ 
+     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
+     optimizer = sgd_optimizer(qat_model, args.lr, args.momentum, args.weight_decay)
+ 
+     warmup_epochs = 1
+     lr_scheduler = WarmupCosineAnnealingLR(optimizer=optimizer, T_cosine_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node,
+                                            eta_min=0, warmup=warmup_epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node)
+ 
+     # optionally resume from a checkpoint
+     if args.resume:
+         if os.path.isfile(args.resume):
+             print("=> loading checkpoint '{}'".format(args.resume))
+             if args.gpu is None:
+                 checkpoint = torch.load(args.resume)
+             else:
+                 # Map model to be loaded to specified single gpu.
+                 loc = 'cuda:{}'.format(args.gpu)
+                 checkpoint = torch.load(args.resume, map_location=loc)
+             args.start_epoch = checkpoint['epoch']
+             best_acc1 = checkpoint['best_acc1']
+             if args.gpu is not None:
+                 # best_acc1 may be from a checkpoint from a different GPU
+                 best_acc1 = best_acc1.to(args.gpu)
+             qat_model.load_state_dict(checkpoint['state_dict'])
+             optimizer.load_state_dict(checkpoint['optimizer'])
+             lr_scheduler.load_state_dict(checkpoint['scheduler'])
+             print("=> loaded checkpoint '{}' (epoch {})"
+                   .format(args.resume, checkpoint['epoch']))
+         else:
+             print("=> no checkpoint found at '{}'".format(args.resume))
+ 
+     cudnn.benchmark = True
+ 
+     train_sampler, train_loader = get_default_ImageNet_train_sampler_loader(args)
+     val_loader = get_default_ImageNet_val_loader(args)
+ 
+     if args.evaluate:
+         validate(val_loader, qat_model, criterion, args)
+         return
+ 
+     for epoch in range(args.start_epoch, args.epochs):
+         if args.distributed:
+             train_sampler.set_epoch(epoch)
+ 
+         # train for one epoch
+         train(train_loader, qat_model, criterion, optimizer, epoch, args, lr_scheduler, is_main=is_main)
+ 
+         if args.fixobserver and epoch > (3 * args.epochs // 8):
+             # Freeze quantizer parameters. Experimental; may not be useful.
+             qat_model.apply(torch.quantization.disable_observer)
+             log_msg('fix observer after epoch {}'.format(epoch), log_file)
+ 
+         if args.fixbn and epoch > (2 * args.epochs // 8):
+             # Freeze batch norm mean and variance estimates. Experimental; may not be useful.
+             qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+             log_msg('fix bn after epoch {}'.format(epoch), log_file)
+ 
+         # evaluate on validation set
+         if is_main:
+             acc1 = validate(val_loader, qat_model, criterion, args)
+             msg = '{}, base {}, quant, epoch {}, QAT acc {}'.format(args.arch, args.base_weights, epoch, acc1)
+             log_msg(msg, log_file)
+ 
+             is_best = acc1 > best_acc1
+             best_acc1 = max(acc1, best_acc1)
+ 
+             save_checkpoint({
+                 'epoch': epoch + 1,
+                 'arch': args.arch,
+                 'state_dict': qat_model.state_dict(),
+                 'best_acc1': best_acc1,
+                 'optimizer': optimizer.state_dict(),
+                 'scheduler': lr_scheduler.state_dict(),
+             }, is_best,
+                 filename='{}_{}.pth.tar'.format(args.arch, args.tag),
+                 best_filename='{}_{}_best.pth.tar'.format(args.arch, args.tag))
+ 
+ 
+ def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler, is_main):
+     batch_time = AverageMeter('Time', ':6.3f')
+     data_time = AverageMeter('Data', ':6.3f')
+     losses = AverageMeter('Loss', ':.4e')
+     top1 = AverageMeter('Acc@1', ':6.2f')
+     top5 = AverageMeter('Acc@5', ':6.2f')
+     progress = ProgressMeter(
+         len(train_loader),
+         [batch_time, data_time, losses, top1, top5],
+         prefix="Epoch: [{}]".format(epoch))
+ 
+     # switch to train mode
+     model.train()
+ 
+     end = time.time()
+     for i, (images, target) in enumerate(train_loader):
+         # measure data loading time
+         data_time.update(time.time() - end)
+ 
+         if args.gpu is not None:
+             images = images.cuda(args.gpu, non_blocking=True)
+         if torch.cuda.is_available():
+             target = target.cuda(args.gpu, non_blocking=True)
+ 
+         # compute output
+         output = model(images)
+         loss = criterion(output, target)
+ 
+         # measure accuracy and record loss
+         acc1, acc5 = accuracy(output, target, topk=(1, 5))
+         losses.update(loss.item(), images.size(0))
+         top1.update(acc1[0], images.size(0))
+         top5.update(acc5[0], images.size(0))
+ 
+         # compute gradient and do SGD step
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+ 
+         # measure elapsed time
+         batch_time.update(time.time() - end)
+         end = time.time()
+ 
+         if lr_scheduler is not None:
+             lr_scheduler.step()
+ 
+         if is_main and i % args.print_freq == 0:
+             progress.display(i)
+         if is_main and i % 1000 == 0 and lr_scheduler is not None:
+             print('cur lr: ', lr_scheduler.get_lr()[0])
+ 
+ 
+ def validate(val_loader, model, criterion, args):
+     batch_time = AverageMeter('Time', ':6.3f')
+     losses = AverageMeter('Loss', ':.4e')
+     top1 = AverageMeter('Acc@1', ':6.2f')
+     top5 = AverageMeter('Acc@5', ':6.2f')
+     progress = ProgressMeter(
+         len(val_loader),
+         [batch_time, losses, top1, top5],
+         prefix='Test: ')
+ 
+     # switch to evaluate mode
+     model.eval()
+ 
+     with torch.no_grad():
+         end = time.time()
+         for i, (images, target) in enumerate(val_loader):
+             images = images.cuda(args.gpu, non_blocking=True)
+             target = target.cuda(args.gpu, non_blocking=True)
+ 
+             # compute output
+             output = model(images)
+             loss = criterion(output, target)
+ 
+             # measure accuracy and record loss
+             acc1, acc5 = accuracy(output, target, topk=(1, 5))
+             losses.update(loss.item(), images.size(0))
+             top1.update(acc1[0], images.size(0))
+             top5.update(acc5[0], images.size(0))
+ 
+             # measure elapsed time
+             batch_time.update(time.time() - end)
+             end = time.time()
+ 
+             if i % args.print_freq == 0:
+                 progress.display(i)
+ 
+         # TODO: this should also be done with the ProgressMeter
+         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+               .format(top1=top1, top5=top5))
+ 
+     return top1.avg
+ 
+ 
+ def save_checkpoint(state, is_best, filename, best_filename):
+     torch.save(state, filename)
+     if is_best:
+         shutil.copyfile(filename, best_filename)
+ 
+ 
+ if __name__ == '__main__':
+     main()
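
After QAT finishes, the trained model still contains fake-quantization modules; a sketch of the final int8 conversion with the standard PyTorch API (following the static-quantization tutorial linked in repvgg_quantized.py below; the file name is a placeholder):

    import torch

    # Unwrap DDP/DataParallel first, then convert on CPU in eval mode.
    model_fp32 = qat_model.module if hasattr(qat_model, 'module') else qat_model
    model_fp32.cpu().eval()
    model_int8 = torch.quantization.convert(model_fp32)  # swaps fake-quant modules for int8 kernels
    torch.save(model_int8.state_dict(), 'repvgg_qat_int8.pth')
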
RepVGG-main/quantization/repvgg_quantized.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ import torch.nn as nn
+ from torch.quantization import QuantStub, DeQuantStub
+ 
+ class RepVGGWholeQuant(nn.Module):
+ 
+     def __init__(self, repvgg_model, quantlayers):
+         super(RepVGGWholeQuant, self).__init__()
+         assert quantlayers in ['all', 'exclude_first_and_linear', 'exclude_first_and_last']
+         self.quantlayers = quantlayers
+         self.quant = QuantStub()
+         self.stage0, self.stage1, self.stage2, self.stage3, self.stage4 = repvgg_model.stage0, repvgg_model.stage1, repvgg_model.stage2, repvgg_model.stage3, repvgg_model.stage4
+         self.gap, self.linear = repvgg_model.gap, repvgg_model.linear
+         self.dequant = DeQuantStub()
+ 
+     def forward(self, x):
+         if self.quantlayers == 'all':
+             x = self.quant(x)
+             out = self.stage0(x)
+         else:
+             # keep the first stage in full precision
+             out = self.stage0(x)
+             out = self.quant(out)
+         out = self.stage1(out)
+         out = self.stage2(out)
+         out = self.stage3(out)
+         if self.quantlayers == 'all':
+             out = self.stage4(out)
+             out = self.gap(out).view(out.size(0), -1)
+             out = self.linear(out)
+             out = self.dequant(out)
+         elif self.quantlayers == 'exclude_first_and_linear':
+             out = self.stage4(out)
+             out = self.dequant(out)
+             out = self.gap(out).view(out.size(0), -1)
+             out = self.linear(out)
+         else:
+             # 'exclude_first_and_last': keep stage4 and the classifier in full precision
+             out = self.dequant(out)
+             out = self.stage4(out)
+             out = self.gap(out).view(out.size(0), -1)
+             out = self.linear(out)
+         return out
+ 
+     # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
+     def fuse_model(self):
+         for m in self.modules():
+             if type(m) == nn.Sequential and hasattr(m, 'conv'):
+                 # Note that we moved ReLU from "block.nonlinearity" into "rbr_reparam" (nn.Sequential).
+                 # This makes it more convenient to fuse operators using off-the-shelf APIs.
+                 torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)
+ 
+     def _get_qconfig(self):
+         return torch.quantization.get_default_qat_qconfig('fbgemm')
+ 
+     def prepare_quant(self):
+         # From https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
+         self.fuse_model()
+         qconfig = self._get_qconfig()
+         self.qconfig = qconfig
+         torch.quantization.prepare_qat(self, inplace=True)
+ 
+     def freeze_quant_bn(self):
+         self.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
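
A condensed sketch of how quant_qat_train.py above wires this wrapper up (the commented weights path is a placeholder):

    from repvgg import get_RepVGG_func_by_name
    from tools.insert_bn import directly_insert_bn_without_init
    from quantization.repvgg_quantized import RepVGGWholeQuant

    base_model = get_RepVGG_func_by_name('RepVGG-A0')(deploy=True)  # start from the re-parameterized model
    directly_insert_bn_without_init(base_model)                     # re-insert BN so conv+bn+relu can be fused for QAT
    # load_checkpoint(base_model, '/path/to/base_weights.pth')
    qat_model = RepVGGWholeQuant(repvgg_model=base_model, quantlayers='all')
    qat_model.prepare_quant()                                       # fuse modules and attach fbgemm QAT observers
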
RepVGG-main/repvgg.py ADDED
@@ -0,0 +1,303 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ import torch
9
+ import copy
10
+ from se_block import SEBlock
11
+ import torch.utils.checkpoint as checkpoint
12
+
13
+ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
14
+ result = nn.Sequential()
15
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
16
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
17
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
18
+ return result
19
+
20
+ class RepVGGBlock(nn.Module):
21
+
22
+ def __init__(self, in_channels, out_channels, kernel_size,
23
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
24
+ super(RepVGGBlock, self).__init__()
25
+ self.deploy = deploy
26
+ self.groups = groups
27
+ self.in_channels = in_channels
28
+
29
+ assert kernel_size == 3
30
+ assert padding == 1
31
+
32
+ padding_11 = padding - kernel_size // 2
33
+
34
+ self.nonlinearity = nn.ReLU()
35
+
36
+ if use_se:
37
+ # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
38
+ self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
39
+ else:
40
+ self.se = nn.Identity()
41
+
42
+ if deploy:
43
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
44
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
45
+
46
+ else:
47
+ self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
48
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
49
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
50
+ print('RepVGG Block, identity = ', self.rbr_identity)
51
+
52
+
53
+ def forward(self, inputs):
54
+ if hasattr(self, 'rbr_reparam'):
55
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
56
+
57
+ if self.rbr_identity is None:
58
+ id_out = 0
59
+ else:
60
+ id_out = self.rbr_identity(inputs)
61
+
62
+ return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
63
+
64
+
65
+ # Optional. This may improve the accuracy and facilitates quantization in some cases.
66
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
67
+ # 2. Use like this.
68
+ # loss = criterion(....)
69
+ # for every RepVGGBlock blk:
70
+ # loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
71
+ # optimizer.zero_grad()
72
+ # loss.backward()
73
+ def get_custom_L2(self):
74
+ K3 = self.rbr_dense.conv.weight
75
+ K1 = self.rbr_1x1.conv.weight
76
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
77
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
78
+
79
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
80
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
81
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
82
+ return l2_loss_eq_kernel + l2_loss_circle
83
+
84
+
85
+
86
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
87
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
88
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
89
+ # May be useful for quantization or pruning.
90
+ def get_equivalent_kernel_bias(self):
91
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
92
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
93
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
94
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
95
+
96
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
97
+ if kernel1x1 is None:
98
+ return 0
99
+ else:
100
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
101
+
102
+ def _fuse_bn_tensor(self, branch):
103
+ if branch is None:
104
+ return 0, 0
105
+ if isinstance(branch, nn.Sequential):
106
+ kernel = branch.conv.weight
107
+ running_mean = branch.bn.running_mean
108
+ running_var = branch.bn.running_var
109
+ gamma = branch.bn.weight
110
+ beta = branch.bn.bias
111
+ eps = branch.bn.eps
112
+ else:
113
+ assert isinstance(branch, nn.BatchNorm2d)
114
+ if not hasattr(self, 'id_tensor'):
115
+ input_dim = self.in_channels // self.groups
116
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
117
+ for i in range(self.in_channels):
118
+ kernel_value[i, i % input_dim, 1, 1] = 1
119
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
120
+ kernel = self.id_tensor
121
+ running_mean = branch.running_mean
122
+ running_var = branch.running_var
123
+ gamma = branch.weight
124
+ beta = branch.bias
125
+ eps = branch.eps
126
+ std = (running_var + eps).sqrt()
127
+ t = (gamma / std).reshape(-1, 1, 1, 1)
128
+ return kernel * t, beta - running_mean * gamma / std
129
+
130
+ def switch_to_deploy(self):
131
+ if hasattr(self, 'rbr_reparam'):
132
+ return
133
+ kernel, bias = self.get_equivalent_kernel_bias()
134
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
135
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
136
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
137
+ self.rbr_reparam.weight.data = kernel
138
+ self.rbr_reparam.bias.data = bias
139
+ self.__delattr__('rbr_dense')
140
+ self.__delattr__('rbr_1x1')
141
+ if hasattr(self, 'rbr_identity'):
142
+ self.__delattr__('rbr_identity')
143
+ if hasattr(self, 'id_tensor'):
144
+ self.__delattr__('id_tensor')
145
+ self.deploy = True
146
+
147
+
148
+
149
+ class RepVGG(nn.Module):
150
+
151
+ def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False):
152
+ super(RepVGG, self).__init__()
153
+ assert len(width_multiplier) == 4
154
+ self.deploy = deploy
155
+ self.override_groups_map = override_groups_map or dict()
156
+ assert 0 not in self.override_groups_map
157
+ self.use_se = use_se
158
+ self.use_checkpoint = use_checkpoint
159
+
160
+ self.in_planes = min(64, int(64 * width_multiplier[0]))
161
+ self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se)
162
+ self.cur_layer_idx = 1
163
+ self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
164
+ self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
165
+ self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
166
+ self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
167
+ self.gap = nn.AdaptiveAvgPool2d(output_size=1)
168
+ self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
169
+
170
+ def _make_stage(self, planes, num_blocks, stride):
171
+ strides = [stride] + [1]*(num_blocks-1)
172
+ blocks = []
173
+ for stride in strides:
174
+ cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
175
+ blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
176
+ stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se))
177
+ self.in_planes = planes
178
+ self.cur_layer_idx += 1
179
+ return nn.ModuleList(blocks)
180
+
181
+ def forward(self, x):
182
+ out = self.stage0(x)
183
+ for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
184
+ for block in stage:
185
+ if self.use_checkpoint:
186
+ out = checkpoint.checkpoint(block, out)
187
+ else:
188
+ out = block(out)
189
+ out = self.gap(out)
190
+ out = out.view(out.size(0), -1)
191
+ out = self.linear(out)
192
+ return out
193
+
194
+
195
+ optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
196
+ g2_map = {l: 2 for l in optional_groupwise_layers}
197
+ g4_map = {l: 4 for l in optional_groupwise_layers}
198
+
199
+ def create_RepVGG_A0(deploy=False, use_checkpoint=False):
200
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
201
+ width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
202
+
203
+ def create_RepVGG_A1(deploy=False, use_checkpoint=False):
204
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
205
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
206
+
207
+ def create_RepVGG_A2(deploy=False, use_checkpoint=False):
208
+ return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
209
+ width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
210
+
211
+ def create_RepVGG_B0(deploy=False, use_checkpoint=False):
212
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
213
+ width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
214
+
215
+ def create_RepVGG_B1(deploy=False, use_checkpoint=False):
216
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
217
+ width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
218
+
219
+ def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
220
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
221
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
222
+
223
+ def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
224
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
225
+ width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
226
+
227
+
228
+ def create_RepVGG_B2(deploy=False, use_checkpoint=False):
229
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
230
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
231
+
232
+ def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
233
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
234
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
235
+
236
+ def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
237
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
238
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
239
+
240
+
241
+ def create_RepVGG_B3(deploy=False, use_checkpoint=False):
242
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
243
+ width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)
244
+
245
+ def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
246
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
247
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)
248
+
249
+ def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
250
+ return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
251
+ width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)
252
+
253
+ def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
254
+ return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
255
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)
256
+
257
+
258
+ func_dict = {
259
+ 'RepVGG-A0': create_RepVGG_A0,
260
+ 'RepVGG-A1': create_RepVGG_A1,
261
+ 'RepVGG-A2': create_RepVGG_A2,
262
+ 'RepVGG-B0': create_RepVGG_B0,
263
+ 'RepVGG-B1': create_RepVGG_B1,
264
+ 'RepVGG-B1g2': create_RepVGG_B1g2,
265
+ 'RepVGG-B1g4': create_RepVGG_B1g4,
266
+ 'RepVGG-B2': create_RepVGG_B2,
267
+ 'RepVGG-B2g2': create_RepVGG_B2g2,
268
+ 'RepVGG-B2g4': create_RepVGG_B2g4,
269
+ 'RepVGG-B3': create_RepVGG_B3,
270
+ 'RepVGG-B3g2': create_RepVGG_B3g2,
271
+ 'RepVGG-B3g4': create_RepVGG_B3g4,
272
+ 'RepVGG-D2se': create_RepVGG_D2se,      #   Updated on April 25, 2021. Not reported in the CVPR paper.
273
+ }
274
+ def get_RepVGG_func_by_name(name):
275
+ return func_dict[name]
276
+
277
+
278
+
279
+ # Use this for converting a RepVGG model or a bigger model with RepVGG as its component
280
+ # Use it like this:
281
+ # model = create_RepVGG_A0(deploy=False)
282
+ # train model or load weights
283
+ # repvgg_model_convert(model, save_path='repvgg_deploy.pth')
284
+ # do_copy=True (the default) preserves the original training-time model; pass do_copy=False to convert it in place
285
+
286
+ # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
287
+ # train_backbone = create_RepVGG_B2(deploy=False)
288
+ # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
289
+ # train_pspnet = build_pspnet(backbone=train_backbone)
290
+ # segmentation_train(train_pspnet)
291
+ # deploy_pspnet = repvgg_model_convert(train_pspnet)
292
+ # segmentation_test(deploy_pspnet)
293
+ # ===================== example_pspnet.py shows an example
294
+
295
+ def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
296
+ if do_copy:
297
+ model = copy.deepcopy(model)
298
+ for module in model.modules():
299
+ if hasattr(module, 'switch_to_deploy'):
300
+ module.switch_to_deploy()
301
+ if save_path is not None:
302
+ torch.save(model.state_dict(), save_path)
303
+ return model
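
A minimal end-to-end sketch of the conversion helper above (hedged: the checkpoint path is a placeholder, and repvgg.py must be on the import path):

    import torch
    from repvgg import create_RepVGG_A0, repvgg_model_convert

    train_model = create_RepVGG_A0(deploy=False)
    train_model.load_state_dict(torch.load('RepVGG-A0-train.pth'))  # placeholder path
    train_model.eval()  # use the BN running statistics before fusing

    # do_copy=True (the default) leaves train_model intact
    deploy_model = repvgg_model_convert(train_model, save_path='repvgg_deploy.pth')

    x = torch.randn(1, 3, 224, 224)
    print(((train_model(x) - deploy_model(x)) ** 2).sum())  # should be ~0
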
RepVGG-main/repvggplus.py ADDED
@@ -0,0 +1,293 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch.nn as nn
8
+ import torch.utils.checkpoint as checkpoint
9
+ from se_block import SEBlock
10
+ import torch
11
+ import numpy as np
12
+
13
+ def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1):
14
+ result = nn.Sequential()
15
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
16
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
17
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
18
+ result.add_module('relu', nn.ReLU())
19
+ return result
20
+
21
+ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
22
+ result = nn.Sequential()
23
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
24
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
25
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
26
+ return result
27
+
28
+ class RepVGGplusBlock(nn.Module):
29
+
30
+ def __init__(self, in_channels, out_channels, kernel_size,
31
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros',
32
+ deploy=False,
33
+ use_post_se=False):
34
+ super(RepVGGplusBlock, self).__init__()
35
+ self.deploy = deploy
36
+ self.groups = groups
37
+ self.in_channels = in_channels
38
+
39
+ assert kernel_size == 3
40
+ assert padding == 1
41
+
42
+ self.nonlinearity = nn.ReLU()
43
+
44
+ if use_post_se:
45
+ self.post_se = SEBlock(out_channels, internal_neurons=out_channels // 4)
46
+ else:
47
+ self.post_se = nn.Identity()
48
+
49
+ if deploy:
50
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
51
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
52
+ else:
53
+ if out_channels == in_channels and stride == 1:
54
+ self.rbr_identity = nn.BatchNorm2d(num_features=out_channels)
55
+ else:
56
+ self.rbr_identity = None
57
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
58
+ padding_11 = padding - kernel_size // 2
59
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
60
+
61
+ def forward(self, x):
62
+ if self.deploy:
63
+ return self.post_se(self.nonlinearity(self.rbr_reparam(x)))
64
+
65
+ if self.rbr_identity is None:
66
+ id_out = 0
67
+ else:
68
+ id_out = self.rbr_identity(x)
69
+ out = self.rbr_dense(x) + self.rbr_1x1(x) + id_out
70
+ out = self.post_se(self.nonlinearity(out))
71
+ return out
72
+
73
+
74
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
75
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
76
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
77
+ # May be useful for quantization or pruning.
78
+ def get_equivalent_kernel_bias(self):
79
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
80
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
81
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
82
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
83
+
84
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
85
+ if kernel1x1 is None:
86
+ return 0
87
+ else:
88
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
89
+
90
+ def _fuse_bn_tensor(self, branch):
91
+ if branch is None:
92
+ return 0, 0
93
+ if isinstance(branch, nn.Sequential):
94
+ # For the 1x1 or 3x3 branch
95
+ kernel, running_mean, running_var, gamma, beta, eps = branch.conv.weight, branch.bn.running_mean, branch.bn.running_var, branch.bn.weight, branch.bn.bias, branch.bn.eps
96
+ else:
97
+ # For the identity branch
98
+ assert isinstance(branch, nn.BatchNorm2d)
99
+ if not hasattr(self, 'id_tensor'):
100
+ # Construct and store the identity kernel in case it is used multiple times
101
+ input_dim = self.in_channels // self.groups
102
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
103
+ for i in range(self.in_channels):
104
+ kernel_value[i, i % input_dim, 1, 1] = 1
105
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
106
+ kernel, running_mean, running_var, gamma, beta, eps = self.id_tensor, branch.running_mean, branch.running_var, branch.weight, branch.bias, branch.eps
107
+ std = (running_var + eps).sqrt()
108
+ t = (gamma / std).reshape(-1, 1, 1, 1)
109
+ return kernel * t, beta - running_mean * gamma / std
110
+
111
+ def switch_to_deploy(self):
112
+ if hasattr(self, 'rbr_reparam'):
113
+ return
114
+ kernel, bias = self.get_equivalent_kernel_bias()
115
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
116
+ out_channels=self.rbr_dense.conv.out_channels,
117
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
118
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
119
+ groups=self.rbr_dense.conv.groups, bias=True)
120
+ self.rbr_reparam.weight.data = kernel
121
+ self.rbr_reparam.bias.data = bias
122
+ self.__delattr__('rbr_dense')
123
+ self.__delattr__('rbr_1x1')
124
+ if hasattr(self, 'rbr_identity'):
125
+ self.__delattr__('rbr_identity')
126
+ if hasattr(self, 'id_tensor'):
127
+ self.__delattr__('id_tensor')
128
+ self.deploy = True
129
+
130
+
131
+
132
+ class RepVGGplusStage(nn.Module):
133
+
134
+ def __init__(self, in_planes, planes, num_blocks, stride, use_checkpoint, use_post_se=False, deploy=False):
135
+ super().__init__()
136
+ strides = [stride] + [1] * (num_blocks - 1)
137
+ blocks = []
138
+ self.in_planes = in_planes
139
+ for stride in strides:
140
+ cur_groups = 1
141
+ blocks.append(RepVGGplusBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
142
+ stride=stride, padding=1, groups=cur_groups, deploy=deploy, use_post_se=use_post_se))
143
+ self.in_planes = planes
144
+ self.blocks = nn.ModuleList(blocks)
145
+ self.use_checkpoint = use_checkpoint
146
+
147
+ def forward(self, x):
148
+ for block in self.blocks:
149
+ if self.use_checkpoint:
150
+ x = checkpoint.checkpoint(block, x)
151
+ else:
152
+ x = block(x)
153
+ return x
154
+
155
+
156
+ class RepVGGplus(nn.Module):
157
+ """RepVGGplus
158
+ An official improved version of RepVGG, described in `RepVGG: Making VGG-style ConvNets Great Again <https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf>`_.
159
+
160
+ Args:
161
+ num_blocks (tuple[int]): Depths of each stage.
162
+ num_classes (int): Number of classes.
163
+ width_multiplier (tuple[float]): The width of the four stages
164
+ will be (64 * width_multiplier[0], 128 * width_multiplier[1], 256 * width_multiplier[2], 512 * width_multiplier[3]).
165
+ deploy (bool, optional): If True, the model will have the inference-time structure.
166
+ Default: False.
167
+ use_post_se (bool, optional): If True, the model will have Squeeze-and-Excitation blocks following the conv-ReLU units.
168
+ Default: False.
169
+ use_checkpoint (bool, optional): If True, the model will use torch.utils.checkpoint to save GPU memory during training, with an acceptable slowdown.
170
+ Do not use it if you have sufficient GPU memory.
171
+ Default: False.
172
+ """
173
+ def __init__(self,
174
+ num_blocks,
175
+ num_classes,
176
+ width_multiplier,
177
+ deploy=False,
178
+ use_post_se=False,
179
+ use_checkpoint=False):
180
+ super().__init__()
181
+
182
+ self.deploy = deploy
183
+ self.num_classes = num_classes
184
+
185
+ in_channels = min(64, int(64 * width_multiplier[0]))
186
+ stage_channels = [int(64 * width_multiplier[0]), int(128 * width_multiplier[1]), int(256 * width_multiplier[2]), int(512 * width_multiplier[3])]
187
+ self.stage0 = RepVGGplusBlock(in_channels=3, out_channels=in_channels, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_post_se=use_post_se)
188
+ self.stage1 = RepVGGplusStage(in_channels, stage_channels[0], num_blocks[0], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
189
+ self.stage2 = RepVGGplusStage(stage_channels[0], stage_channels[1], num_blocks[1], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
190
+ # split stage3 so that we can insert an auxiliary classifier
191
+ self.stage3_first = RepVGGplusStage(stage_channels[1], stage_channels[2], num_blocks[2] // 2, stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
192
+ self.stage3_second = RepVGGplusStage(stage_channels[2], stage_channels[2], num_blocks[2] - num_blocks[2] // 2, stride=1, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
193
+ self.stage4 = RepVGGplusStage(stage_channels[2], stage_channels[3], num_blocks[3], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
194
+ self.gap = nn.AdaptiveAvgPool2d(output_size=1)
195
+ self.flatten = nn.Flatten()
196
+ self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
197
+ # aux classifiers
198
+ if not self.deploy:
199
+ self.stage1_aux = self._build_aux_for_stage(self.stage1)
200
+ self.stage2_aux = self._build_aux_for_stage(self.stage2)
201
+ self.stage3_first_aux = self._build_aux_for_stage(self.stage3_first)
202
+
203
+ def _build_aux_for_stage(self, stage):
204
+ stage_out_channels = list(stage.blocks.children())[-1].rbr_dense.conv.out_channels
205
+ downsample = conv_bn_relu(in_channels=stage_out_channels, out_channels=stage_out_channels, kernel_size=3, stride=2, padding=1)
206
+ fc = nn.Linear(stage_out_channels, self.num_classes, bias=True)
207
+ return nn.Sequential(downsample, nn.AdaptiveAvgPool2d(1), nn.Flatten(), fc)
208
+
209
+ def forward(self, x):
210
+ out = self.stage0(x)
211
+ out = self.stage1(out)
212
+ stage1_aux = self.stage1_aux(out)
213
+ out = self.stage2(out)
214
+ stage2_aux = self.stage2_aux(out)
215
+ out = self.stage3_first(out)
216
+ stage3_first_aux = self.stage3_first_aux(out)
217
+ out = self.stage3_second(out)
218
+ out = self.stage4(out)
219
+ y = self.gap(out)
220
+ y = self.flatten(y)
221
+ y = self.linear(y)
222
+ return {
223
+ 'main': y,
224
+ 'stage1_aux': stage1_aux,
225
+ 'stage2_aux': stage2_aux,
226
+ 'stage3_first_aux': stage3_first_aux,
227
+ }
228
+
229
+ def switch_repvggplus_to_deploy(self):
230
+ for m in self.modules():
231
+ if hasattr(m, 'switch_to_deploy'):
232
+ m.switch_to_deploy()
233
+ if hasattr(self, 'stage1_aux'):
234
+ self.__delattr__('stage1_aux')
235
+ if hasattr(self, 'stage2_aux'):
236
+ self.__delattr__('stage2_aux')
237
+ if hasattr(self, 'stage3_first_aux'):
238
+ self.__delattr__('stage3_first_aux')
239
+ self.deploy = True
240
+
241
+
242
+ # torch.utils.checkpoint can reduce the memory consumption during training with a minor slowdown. Don't use it if you have sufficient GPU memory.
243
+ # Not sure whether it slows down inference
244
+ # "pse" stands for "post SE", i.e., an SE block placed after ReLU
245
+ def create_RepVGGplus_L2pse(deploy=False, use_checkpoint=False):
246
+ return RepVGGplus(num_blocks=[8, 14, 24, 1], num_classes=1000,
247
+ width_multiplier=[2.5, 2.5, 2.5, 5], deploy=deploy, use_post_se=True,
248
+ use_checkpoint=use_checkpoint)
249
+
250
+ # More variants will be released
251
+ repvggplus_func_dict = {
252
+ 'RepVGGplus-L2pse': create_RepVGGplus_L2pse,
253
+ }
254
+
255
+ def create_RepVGGplus_by_name(name, deploy=False, use_checkpoint=False):
256
+ if 'plus' in name:
257
+ return repvggplus_func_dict[name](deploy=deploy, use_checkpoint=use_checkpoint)
258
+ else:
259
+ print('=================== Building the vanilla RepVGG ===================')
260
+ from repvgg import get_RepVGG_func_by_name
261
+ return get_RepVGG_func_by_name(name)(deploy=deploy, use_checkpoint=use_checkpoint)
262
+
263
+
264
+
265
+
266
+
267
+
268
+ # Use this for converting a RepVGG model or a bigger model with RepVGG as its component
269
+ # Use it like this:
270
+ # model = create_RepVGG_A0(deploy=False)
271
+ # train model or load weights
272
+ # repvgg_model_convert(model, save_path='repvgg_deploy.pth')
273
+ # do_copy=True (the default) preserves the original training-time model; pass do_copy=False to convert it in place
274
+
275
+ # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
276
+ # train_backbone = create_RepVGG_B2(deploy=False)
277
+ # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
278
+ # train_pspnet = build_pspnet(backbone=train_backbone)
279
+ # segmentation_train(train_pspnet)
280
+ # deploy_pspnet = repvgg_model_convert(train_pspnet)
281
+ # segmentation_test(deploy_pspnet)
282
+ # ===================== example_pspnet.py shows an example
283
+
284
+ def repvgg_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
285
+ import copy
286
+ if do_copy:
287
+ model = copy.deepcopy(model)
288
+ for module in model.modules():
289
+ if hasattr(module, 'switch_to_deploy'):
290
+ module.switch_to_deploy()
291
+ if save_path is not None:
292
+ torch.save(model.state_dict(), save_path)
293
+ return model
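
As a sanity check of the structural re-parameterization above, a small sketch (assuming repvggplus.py is importable) comparing a training-time RepVGGplusBlock against its fused single-conv equivalent:

    import torch
    import torch.nn.functional as F
    from repvggplus import RepVGGplusBlock

    block = RepVGGplusBlock(in_channels=8, out_channels=8, kernel_size=3, padding=1)
    block.eval()  # use the BN running statistics, as at inference time

    x = torch.randn(2, 8, 16, 16)
    y_train = block(x)  # 3x3 + 1x1 + identity branches

    kernel, bias = block.get_equivalent_kernel_bias()
    y_fused = block.post_se(block.nonlinearity(F.conv2d(x, kernel, bias, stride=1, padding=1)))
    print(((y_train - y_fused) ** 2).sum())  # should be ~0
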
RepVGG-main/repvggplus_custom_L2.py ADDED
@@ -0,0 +1,268 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint as checkpoint
8
+ from se_block import SEBlock
9
+ import torch
10
+ import numpy as np
11
+
12
+
13
+ def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1):
14
+ result = nn.Sequential()
15
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
16
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
17
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
18
+ result.add_module('relu', nn.ReLU())
19
+ return result
20
+
21
+ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
22
+ result = nn.Sequential()
23
+ result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
24
+ kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
25
+ result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
26
+ return result
27
+
28
+ class RepVGGplusBlock(nn.Module):
29
+
30
+ def __init__(self, in_channels, out_channels, kernel_size,
31
+ stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros',
32
+ deploy=False,
33
+ use_post_se=False):
34
+ super(RepVGGplusBlock, self).__init__()
35
+ self.deploy = deploy
36
+ self.groups = groups
37
+ self.in_channels = in_channels
38
+
39
+ assert kernel_size == 3
40
+ assert padding == 1
41
+
42
+ self.nonlinearity = nn.ReLU()
43
+
44
+ if use_post_se:
45
+ self.post_se = SEBlock(out_channels, internal_neurons=out_channels // 4)
46
+ else:
47
+ self.post_se = nn.Identity()
48
+
49
+ if deploy:
50
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
51
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
52
+ else:
53
+ if out_channels == in_channels and stride == 1:
54
+ self.rbr_identity = nn.BatchNorm2d(num_features=out_channels)
55
+ else:
56
+ self.rbr_identity = None
57
+ self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
58
+ padding_11 = padding - kernel_size // 2
59
+ self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
60
+
61
+ def forward(self, x, L2):
62
+
63
+ if self.deploy:
64
+ return self.post_se(self.nonlinearity(self.rbr_reparam(x))), None
65
+
66
+ if self.rbr_identity is None:
67
+ id_out = 0
68
+ else:
69
+ id_out = self.rbr_identity(x)
70
+ out = self.rbr_dense(x) + self.rbr_1x1(x) + id_out
71
+ out = self.post_se(self.nonlinearity(out))
72
+
73
+ # Custom L2
74
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
75
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
76
+ K3 = self.rbr_dense.conv.weight
77
+ K1 = self.rbr_1x1.conv.weight
78
+
79
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()
80
+ eq_kernel = K3[:,:,1:2,1:2] * t3 + K1 * t1
81
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()
82
+
83
+ return out, L2 + l2_loss_circle + l2_loss_eq_kernel
84
+
85
+
86
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
87
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
88
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
89
+ # May be useful for quantization or pruning.
90
+ def get_equivalent_kernel_bias(self):
91
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
92
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
93
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
94
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
95
+
96
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
97
+ if kernel1x1 is None:
98
+ return 0
99
+ else:
100
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
101
+
102
+ def _fuse_bn_tensor(self, branch):
103
+ if branch is None:
104
+ return 0, 0
105
+ if isinstance(branch, nn.Sequential):
106
+ # For the 1x1 or 3x3 branch
107
+ kernel, running_mean, running_var, gamma, beta, eps = branch.conv.weight, branch.bn.running_mean, branch.bn.running_var, branch.bn.weight, branch.bn.bias, branch.bn.eps
108
+ else:
109
+ # For the identity branch
110
+ assert isinstance(branch, nn.BatchNorm2d)
111
+ if not hasattr(self, 'id_tensor'):
112
+ # Construct and store the identity kernel in case it is used multiple times
113
+ input_dim = self.in_channels // self.groups
114
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
115
+ for i in range(self.in_channels):
116
+ kernel_value[i, i % input_dim, 1, 1] = 1
117
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
118
+ kernel, running_mean, running_var, gamma, beta, eps = self.id_tensor, branch.running_mean, branch.running_var, branch.weight, branch.bias, branch.eps
119
+ std = (running_var + eps).sqrt()
120
+ t = (gamma / std).reshape(-1, 1, 1, 1)
121
+ return kernel * t, beta - running_mean * gamma / std
122
+
123
+ def switch_to_deploy(self):
124
+ if hasattr(self, 'rbr_reparam'):
125
+ return
126
+ kernel, bias = self.get_equivalent_kernel_bias()
127
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
128
+ out_channels=self.rbr_dense.conv.out_channels,
129
+ kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
130
+ padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
131
+ groups=self.rbr_dense.conv.groups, bias=True)
132
+ self.rbr_reparam.weight.data = kernel
133
+ self.rbr_reparam.bias.data = bias
134
+ self.__delattr__('rbr_dense')
135
+ self.__delattr__('rbr_1x1')
136
+ if hasattr(self, 'rbr_identity'):
137
+ self.__delattr__('rbr_identity')
138
+ if hasattr(self, 'id_tensor'):
139
+ self.__delattr__('id_tensor')
140
+ self.deploy = True
141
+
142
+
143
+
144
+ class RepVGGplusStage(nn.Module):
145
+
146
+ def __init__(self, in_planes, planes, num_blocks, stride, use_checkpoint, use_post_se=False, deploy=False):
147
+ super().__init__()
148
+ strides = [stride] + [1] * (num_blocks - 1)
149
+ blocks = []
150
+ self.in_planes = in_planes
151
+ for stride in strides:
152
+ cur_groups = 1
153
+ blocks.append(RepVGGplusBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
154
+ stride=stride, padding=1, groups=cur_groups, deploy=deploy, use_post_se=use_post_se))
155
+ self.in_planes = planes
156
+ self.blocks = nn.ModuleList(blocks)
157
+ self.use_checkpoint = use_checkpoint
158
+
159
+ def forward(self, x, L2):
160
+ for block in self.blocks:
161
+ if self.use_checkpoint:
162
+ x, L2 = checkpoint.checkpoint(block, x, L2)
163
+ else:
164
+ x, L2 = block(x, L2)
165
+ return x, L2
166
+
167
+
168
+ class RepVGGplus(nn.Module):
169
+
170
+ def __init__(self, num_blocks, num_classes,
171
+ width_multiplier, override_groups_map=None,
172
+ deploy=False,
173
+ use_post_se=False,
174
+ use_checkpoint=False):
175
+ super().__init__()
176
+
177
+ self.deploy = deploy
178
+ self.override_groups_map = override_groups_map or dict()
179
+ self.use_post_se = use_post_se
180
+ self.use_checkpoint = use_checkpoint
181
+ self.num_classes = num_classes
182
+ self.nonlinear = 'relu'
183
+
184
+ self.in_planes = min(64, int(64 * width_multiplier[0]))
185
+ self.stage0 = RepVGGplusBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_post_se=use_post_se)
186
+ self.cur_layer_idx = 1
187
+ self.stage1 = RepVGGplusStage(self.in_planes, int(64 * width_multiplier[0]), num_blocks[0], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
188
+ self.stage2 = RepVGGplusStage(int(64 * width_multiplier[0]), int(128 * width_multiplier[1]), num_blocks[1], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
189
+ # split stage3 so that we can insert an auxiliary classifier
190
+ self.stage3_first = RepVGGplusStage(int(128 * width_multiplier[1]), int(256 * width_multiplier[2]), num_blocks[2] // 2, stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
191
+ self.stage3_second = RepVGGplusStage(int(256 * width_multiplier[2]), int(256 * width_multiplier[2]), num_blocks[2] // 2, stride=1, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
192
+ self.stage4 = RepVGGplusStage(int(256 * width_multiplier[2]), int(512 * width_multiplier[3]), num_blocks[3], stride=2, use_checkpoint=use_checkpoint, use_post_se=use_post_se, deploy=deploy)
193
+ self.gap = nn.AdaptiveAvgPool2d(output_size=1)
194
+ self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
195
+ # aux classifiers
196
+ if not self.deploy:
197
+ self.stage1_aux = self._build_aux_for_stage(self.stage1)
198
+ self.stage2_aux = self._build_aux_for_stage(self.stage2)
199
+ self.stage3_first_aux = self._build_aux_for_stage(self.stage3_first)
200
+
201
+ def _build_aux_for_stage(self, stage):
202
+ stage_out_channels = list(stage.blocks.children())[-1].rbr_dense.conv.out_channels
203
+ downsample = conv_bn_relu(in_channels=stage_out_channels, out_channels=stage_out_channels, kernel_size=3, stride=2, padding=1)
204
+ fc = nn.Linear(stage_out_channels, self.num_classes, bias=True)
205
+ return nn.Sequential(downsample, nn.AdaptiveAvgPool2d(1), nn.Flatten(), fc)
206
+
207
+ def forward(self, x):
208
+ if self.deploy:
209
+ out, _ = self.stage0(x, L2=None)
210
+ out, _ = self.stage1(out, L2=None)
211
+ out, _ = self.stage2(out, L2=None)
212
+ out, _ = self.stage3_first(out, L2=None)
213
+ out, _ = self.stage3_second(out, L2=None)
214
+ out, _ = self.stage4(out, L2=None)
215
+ y = self.gap(out)
216
+ y = y.view(y.size(0), -1)
217
+ y = self.linear(y)
218
+ return y
219
+
220
+ else:
221
+ out, L2 = self.stage0(x, L2=0.0)
222
+ out, L2 = self.stage1(out, L2=L2)
223
+ stage1_aux = self.stage1_aux(out)
224
+ out, L2 = self.stage2(out, L2=L2)
225
+ stage2_aux = self.stage2_aux(out)
226
+ out, L2 = self.stage3_first(out, L2=L2)
227
+ stage3_first_aux = self.stage3_first_aux(out)
228
+ out, L2 = self.stage3_second(out, L2=L2)
229
+ out, L2 = self.stage4(out, L2=L2)
230
+ y = self.gap(out)
231
+ y = y.view(y.size(0), -1)
232
+ y = self.linear(y)
233
+ return {
234
+ 'main': y,
235
+ 'stage1_aux': stage1_aux,
236
+ 'stage2_aux': stage2_aux,
237
+ 'stage3_first_aux': stage3_first_aux,
238
+ 'L2': L2
239
+ }
240
+
241
+ def switch_repvggplus_to_deploy(self):
242
+ for m in self.modules():
243
+ if hasattr(m, 'switch_to_deploy'):
244
+ m.switch_to_deploy()
245
+ if hasattr(m, 'use_checkpoint'):
246
+ m.use_checkpoint = False # Disable checkpoint. I am not sure whether using checkpoint slows down inference.
247
+ if hasattr(self, 'stage1_aux'):
248
+ self.__delattr__('stage1_aux')
249
+ if hasattr(self, 'stage2_aux'):
250
+ self.__delattr__('stage2_aux')
251
+ if hasattr(self, 'stage3_first_aux'):
252
+ self.__delattr__('stage3_first_aux')
253
+ self.deploy = True
254
+
255
+
256
+ # torch.utils.checkpoint can reduce the memory consumption during training with a minor slowdown. Don't use it if you have sufficient GPU memory.
257
+ # Not sure whether it slows down inference
258
+ # "pse" stands for "post SE", i.e., an SE block placed after ReLU
259
+ def create_RepVGGplus_L2pse(deploy=False, use_checkpoint=False):
260
+ return RepVGGplus(num_blocks=[8, 14, 24, 1], num_classes=1000,
261
+ width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_post_se=True,
262
+ use_checkpoint=use_checkpoint)
263
+
264
+ repvggplus_func_dict = {
265
+ 'RepVGGplus-L2pse': create_RepVGGplus_L2pse,
266
+ }
267
+ def get_RepVGGplus_func_by_name(name):
268
+ return repvggplus_func_dict[name]
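
A hedged sketch of how the 'L2' term returned by this custom model might enter the training objective; the auxiliary and L2 loss weights below are illustrative assumptions, not values taken from this repo:

    import torch
    import torch.nn.functional as F
    from repvggplus_custom_L2 import create_RepVGGplus_L2pse

    model = create_RepVGGplus_L2pse(deploy=False)  # a large model; this is illustrative only
    images = torch.randn(2, 3, 224, 224)
    targets = torch.randint(0, 1000, (2,))

    out = model(images)
    loss = F.cross_entropy(out['main'], targets)
    for aux in ('stage1_aux', 'stage2_aux', 'stage3_first_aux'):
        loss = loss + 0.1 * F.cross_entropy(out[aux], targets)  # aux weight is an assumption
    # the custom term is intended as a weight decay on the re-parameterized kernel
    loss = loss + 1e-4 * out['L2']
    loss.backward()
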
RepVGG-main/se_block.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.html
6
+
7
+ class SEBlock(nn.Module):
8
+
9
+ def __init__(self, input_channels, internal_neurons):
10
+ super(SEBlock, self).__init__()
11
+ self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
12
+ self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True)
13
+ self.input_channels = input_channels
14
+
15
+ def forward(self, inputs):
16
+ x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
17
+ x = self.down(x)
18
+ x = F.relu(x)
19
+ x = self.up(x)
20
+ x = torch.sigmoid(x)
21
+ x = x.view(-1, self.input_channels, 1, 1)
22
+ return inputs * x
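
A minimal usage sketch: the SE block squeezes to a per-channel gate in [0, 1] and rescales the input, preserving its shape:

    import torch
    from se_block import SEBlock

    se = SEBlock(input_channels=64, internal_neurons=16)  # 4x reduction, as used above
    x = torch.randn(2, 64, 32, 32)
    print(se(x).shape)  # torch.Size([2, 64, 32, 32])
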
RepVGG-main/speed_acc.PNG ADDED
RepVGG-main/table.PNG ADDED
RepVGG-main/tools/convert.py ADDED
@@ -0,0 +1,46 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import argparse
7
+ import os
8
+ import torch
9
+ import torch.nn.parallel
10
+ import torch.optim
11
+ import torch.utils.data
12
+ import torch.utils.data.distributed
13
+ from repvggplus import create_RepVGGplus_by_name, repvgg_model_convert
14
+
15
+ parser = argparse.ArgumentParser(description='RepVGG(plus) Conversion')
16
+ parser.add_argument('load', metavar='LOAD', help='path to the weights file')
17
+ parser.add_argument('save', metavar='SAVE', help='path to save the converted weights file')
18
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')
19
+
20
+ def convert():
21
+ args = parser.parse_args()
22
+
23
+ train_model = create_RepVGGplus_by_name(args.arch, deploy=False)
24
+
25
+ if os.path.isfile(args.load):
26
+ print("=> loading checkpoint '{}'".format(args.load))
27
+ checkpoint = torch.load(args.load)
28
+ if 'state_dict' in checkpoint:
29
+ checkpoint = checkpoint['state_dict']
30
+ elif 'model' in checkpoint:
31
+ checkpoint = checkpoint['model']
32
+ ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}   # strip the 'module.' prefix added by DataParallel
33
+ print(ckpt.keys())
34
+ train_model.load_state_dict(ckpt)
35
+ else:
36
+ print("=> no checkpoint found at '{}'".format(args.load))
37
+
38
+ if 'plus' in args.arch:
39
+ train_model.switch_repvggplus_to_deploy()
40
+ torch.save(train_model.state_dict(), args.save)
41
+ else:
42
+ repvgg_model_convert(train_model, save_path=args.save)
43
+
44
+
45
+ if __name__ == '__main__':
46
+ convert()
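
An example invocation would look like `python convert.py RepVGG-A0-train.pth RepVGG-A0-deploy.pth -a RepVGG-A0` (file names are placeholders). The converted weights then load directly into a deploy-mode model:

    import torch
    from repvgg import create_RepVGG_A0

    deploy_model = create_RepVGG_A0(deploy=True)
    deploy_model.load_state_dict(torch.load('RepVGG-A0-deploy.pth'))  # placeholder path
    deploy_model.eval()
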
RepVGG-main/tools/insert_bn.py ADDED
@@ -0,0 +1,217 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import argparse
7
+ import os
8
+ import time
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.parallel
12
+ import torch.backends.cudnn as cudnn
13
+ import torch.optim
14
+ import torch.utils.data
15
+ import torch.utils.data.distributed
16
+ from utils import accuracy, ProgressMeter, AverageMeter
17
+ from repvgg import get_RepVGG_func_by_name, RepVGGBlock
18
+ from utils import load_checkpoint, get_ImageNet_train_dataset, get_default_train_trans
19
+
20
+ # Insert BN into an inference-time RepVGG (e.g., for quantization-aware training).
21
+ # Record the mean and std of every conv3x3 output (before the bias is added) on the train set, then use these statistics to initialize BN layers inserted after each conv3x3.
22
+ # May, 07, 2021
23
+
24
+ parser = argparse.ArgumentParser(description='Get the mean and std on every conv3x3 (before the bias-adding) on the train set. Then use such data to initialize BN layers and insert them after conv3x3.')
25
+ parser.add_argument('data', metavar='DIR', help='path to dataset')
26
+ parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
27
+ parser.add_argument('save', metavar='SAVE', help='path to save the model with BN')
28
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')
29
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
30
+ help='number of data loading workers (default: 4)')
31
+ parser.add_argument('-b', '--batch-size', default=100, type=int,
32
+ metavar='N',
33
+ help='mini-batch size (default: 100) for test')
34
+ parser.add_argument('-n', '--num-batches', default=500, type=int,
35
+ metavar='N',
36
+ help='number of batches (default: 500) to record the mean and std on the train set')
37
+ parser.add_argument('-r', '--resolution', default=224, type=int,
38
+ metavar='R',
39
+ help='resolution (default: 224) for test')
40
+
41
+
42
+ def update_running_mean_var(x, running_mean, running_var, momentum=0.9, is_first_batch=False):
43
+ mean = x.mean(dim=(0, 2, 3), keepdim=True)
44
+ var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
45
+ if is_first_batch:
46
+ running_mean = mean
47
+ running_var = var
48
+ else:
49
+ running_mean = momentum * running_mean + (1.0 - momentum) * mean
50
+ running_var = momentum * running_var + (1.0 - momentum) * var
51
+ return running_mean, running_var
52
+
53
+ # Record the running mean and variance like a BN layer but perform no normalization
54
+ class BNStatistics(nn.Module):
55
+ def __init__(self, num_features):
56
+ super(BNStatistics, self).__init__()
57
+ shape = (1, num_features, 1, 1)
58
+ self.register_buffer('running_mean', torch.zeros(shape))
59
+ self.register_buffer('running_var', torch.zeros(shape))
60
+ self.is_first_batch = True
61
+
62
+ def forward(self, x):
63
+ if self.running_mean.device != x.device:
64
+ self.running_mean = self.running_mean.to(x.device)
65
+ self.running_var = self.running_var.to(x.device)
66
+ self.running_mean, self.running_var = update_running_mean_var(x, self.running_mean, self.running_var, momentum=0.9, is_first_batch=self.is_first_batch)
67
+ self.is_first_batch = False
68
+ return x
69
+
70
+ # This is designed to insert a BNStatistics layer between a bias-free Conv2d and its bias add
71
+ class BiasAdd(nn.Module):
72
+ def __init__(self, num_features):
73
+ super(BiasAdd, self).__init__()
74
+ self.bias = torch.nn.Parameter(torch.Tensor(num_features))
75
+ def forward(self, x):
76
+ return x + self.bias.view(1, -1, 1, 1)
77
+
78
+ def switch_repvggblock_to_bnstat(model):
79
+ for n, block in model.named_modules():
80
+ if isinstance(block, RepVGGBlock):
81
+ print('switch to BN Statistics: ', n)
82
+ assert hasattr(block, 'rbr_reparam')
83
+ stat = nn.Sequential()
84
+ stat.add_module('conv', nn.Conv2d(block.rbr_reparam.in_channels, block.rbr_reparam.out_channels,
85
+ block.rbr_reparam.kernel_size,
86
+ block.rbr_reparam.stride, block.rbr_reparam.padding,
87
+ block.rbr_reparam.dilation,
88
+ block.rbr_reparam.groups, bias=False)) # Note bias=False
89
+ stat.add_module('bnstat', BNStatistics(block.rbr_reparam.out_channels))
90
+ stat.add_module('biasadd', BiasAdd(block.rbr_reparam.out_channels)) # Bias is here
91
+ stat.conv.weight.data = block.rbr_reparam.weight.data
92
+ stat.biasadd.bias.data = block.rbr_reparam.bias.data
93
+ block.__delattr__('rbr_reparam')
94
+ block.rbr_reparam = stat
95
+
96
+ def switch_bnstat_to_convbn(model):
97
+ for n, block in model.named_modules():
98
+ if isinstance(block, RepVGGBlock):
99
+ assert hasattr(block, 'rbr_reparam')
100
+ assert hasattr(block.rbr_reparam, 'bnstat')
101
+ print('switch to ConvBN: ', n)
102
+ conv = nn.Conv2d(block.rbr_reparam.conv.in_channels, block.rbr_reparam.conv.out_channels,
103
+ block.rbr_reparam.conv.kernel_size,
104
+ block.rbr_reparam.conv.stride, block.rbr_reparam.conv.padding,
105
+ block.rbr_reparam.conv.dilation,
106
+ block.rbr_reparam.conv.groups, bias=False)
107
+ bn = nn.BatchNorm2d(block.rbr_reparam.conv.out_channels)
108
+ bn.running_mean = block.rbr_reparam.bnstat.running_mean.squeeze() # Initialize the mean and var of BN with the statistics
109
+ bn.running_var = block.rbr_reparam.bnstat.running_var.squeeze()
110
+ std = (bn.running_var + bn.eps).sqrt()
111
+ conv.weight.data = block.rbr_reparam.conv.weight.data
112
+ bn.weight.data = std
113
+ bn.bias.data = block.rbr_reparam.biasadd.bias.data + bn.running_mean # Initialize gamma = std and beta = bias + mean
114
+
115
+ convbn = nn.Sequential()
116
+ convbn.add_module('conv', conv)
117
+ convbn.add_module('bn', bn)
118
+ block.__delattr__('rbr_reparam')
119
+ block.rbr_reparam = convbn
120
+
121
+
122
+ # Insert a BN after conv3x3 (rbr_reparam). Without a reasonable initialization of the BN, the model may break down.
123
+ # So you have to load the weights obtained through the BN statistics (please see the function "insert_bn" in this file).
124
+ def directly_insert_bn_without_init(model):
125
+ for n, block in model.named_modules():
126
+ if isinstance(block, RepVGGBlock):
127
+ print('directly insert a BN with no initialization: ', n)
128
+ assert hasattr(block, 'rbr_reparam')
129
+ convbn = nn.Sequential()
130
+ convbn.add_module('conv', nn.Conv2d(block.rbr_reparam.in_channels, block.rbr_reparam.out_channels,
131
+ block.rbr_reparam.kernel_size,
132
+ block.rbr_reparam.stride, block.rbr_reparam.padding,
133
+ block.rbr_reparam.dilation,
134
+ block.rbr_reparam.groups, bias=False)) # Note bias=False
135
+ convbn.add_module('bn', nn.BatchNorm2d(block.rbr_reparam.out_channels))
136
+ # ====================
137
+ convbn.add_module('relu', nn.ReLU())
138
+ # TODO we moved ReLU from "block.nonlinearity" into "rbr_reparam" (nn.Sequential). This makes it more convenient to fuse operators (see RepVGGWholeQuant.fuse_model) using off-the-shelf APIs.
139
+ block.nonlinearity = nn.Identity()
140
+ #==========================
141
+ block.__delattr__('rbr_reparam')
142
+ block.rbr_reparam = convbn
143
+
144
+
145
+ def insert_bn():
146
+ args = parser.parse_args()
147
+
148
+ repvgg_build_func = get_RepVGG_func_by_name(args.arch)
149
+
150
+ model = repvgg_build_func(deploy=True).cuda()
151
+
152
+ load_checkpoint(model, args.weights)
153
+
154
+ switch_repvggblock_to_bnstat(model)
155
+
156
+ cudnn.benchmark = True
157
+
158
+ trans = get_default_train_trans(args)
159
+ print('data aug: ', trans)
160
+
161
+ train_dataset = get_ImageNet_train_dataset(args, trans)
162
+
163
+ train_loader = torch.utils.data.DataLoader(
164
+ train_dataset,
165
+ batch_size=args.batch_size, shuffle=False,
166
+ num_workers=args.workers, pin_memory=True)
167
+
168
+ batch_time = AverageMeter('Time', ':6.3f')
169
+ losses = AverageMeter('Loss', ':.4e')
170
+ top1 = AverageMeter('Acc@1', ':6.2f')
171
+ top5 = AverageMeter('Acc@5', ':6.2f')
172
+
173
+ progress = ProgressMeter(
174
+ min(len(train_loader), args.num_batches),
175
+ [batch_time, losses, top1, top5],
176
+ prefix='BN stat: ')
177
+
178
+ criterion = nn.CrossEntropyLoss().cuda()
179
+
180
+ with torch.no_grad():
181
+ end = time.time()
182
+ for i, (images, target) in enumerate(train_loader):
183
+ if i >= args.num_batches:
184
+ break
185
+ images = images.cuda(non_blocking=True)
186
+ target = target.cuda(non_blocking=True)
187
+
188
+ # compute output
189
+ output = model(images)
190
+ loss = criterion(output, target)
191
+
192
+ # measure accuracy and record loss
193
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
194
+ losses.update(loss.item(), images.size(0))
195
+ top1.update(acc1[0], images.size(0))
196
+ top5.update(acc5[0], images.size(0))
197
+
198
+ # measure elapsed time
199
+ batch_time.update(time.time() - end)
200
+ end = time.time()
201
+
202
+ if i % 10 == 0:
203
+ progress.display(i)
204
+
205
+
206
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
207
+ .format(top1=top1, top5=top5))
208
+
209
+ switch_bnstat_to_convbn(model)
210
+
211
+ torch.save(model.state_dict(), args.save)
212
+
213
+
214
+
215
+
216
+ if __name__ == '__main__':
217
+ insert_bn()
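
Why the initialization in switch_bnstat_to_convbn preserves the network function: an initialized BN computes gamma * (x - mean) / std + beta, and with gamma = std and beta = bias + mean this collapses to x + bias, exactly the conv-plus-bias it replaces (the eps cancels because gamma is set to the same std). A tiny numeric check of that identity, with illustrative shapes:

    import torch

    x = torch.randn(4, 8, 6, 6)
    bias = torch.randn(8).view(1, -1, 1, 1)
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    std = (var + 1e-5).sqrt()

    gamma, beta = std, bias + mean            # the initialization used above
    bn_out = gamma * (x - mean) / std + beta  # what the freshly initialized BN computes
    print(((bn_out - (x + bias)) ** 2).sum())  # should be ~0
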
RepVGG-main/tools/verify.py ADDED
@@ -0,0 +1,30 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import torch
7
+ import torch.nn as nn
8
+ from repvgg import create_RepVGG_B1
9
+
10
+ if __name__ == '__main__':
11
+ x = torch.randn(1, 3, 224, 224)
12
+ model = create_RepVGG_B1(deploy=False)
13
+ model.eval()
14
+
15
+ for module in model.modules():
16
+ if isinstance(module, torch.nn.BatchNorm2d):
17
+ nn.init.uniform_(module.running_mean, 0, 0.1)
18
+ nn.init.uniform_(module.running_var, 0, 0.1)
19
+ nn.init.uniform_(module.weight, 0, 0.1)
20
+ nn.init.uniform_(module.bias, 0, 0.1)
21
+
22
+ train_y = model(x)
23
+ for module in model.modules():
24
+ if hasattr(module, 'switch_to_deploy'):
25
+ module.switch_to_deploy()
26
+
27
+ print(model)
28
+ deploy_y = model(x)
29
+ print('========================== The diff is')
30
+ print(((train_y - deploy_y) ** 2).sum())
RepVGG-main/train/config.py ADDED
@@ -0,0 +1,217 @@
1
+ # --------------------------------------------------------
2
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
3
+ # Github source: https://github.com/DingXiaoH/RepVGG
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import yaml
10
+ from yacs.config import CfgNode as CN
11
+
12
+ _C = CN()
13
+
14
+ # Base config files
15
+ _C.BASE = ['']
16
+
17
+ # -----------------------------------------------------------------------------
18
+ # Data settings
19
+ # -----------------------------------------------------------------------------
20
+ _C.DATA = CN()
21
+ # Batch size for a single GPU, could be overwritten by command line argument
22
+ _C.DATA.BATCH_SIZE = 128
23
+ # Path to dataset, could be overwritten by command line argument
24
+ _C.DATA.DATA_PATH = '/your/path/to/dataset'
25
+
26
+ # Dataset name
27
+ _C.DATA.DATASET = 'imagenet'
28
+ # Input image size
29
+ _C.DATA.IMG_SIZE = 224
30
+ _C.DATA.TEST_SIZE = None
31
+ _C.DATA.TEST_BATCH_SIZE = None
32
+ # Interpolation to resize image (random, bilinear, bicubic)
33
+ _C.DATA.INTERPOLATION = 'bilinear'
34
+ # Use zipped dataset instead of folder dataset
35
+ # could be overwritten by command line argument
36
+ _C.DATA.ZIP_MODE = False
37
+ # Cache Data in Memory, could be overwritten by command line argument
38
+ _C.DATA.CACHE_MODE = 'part'
39
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
40
+ _C.DATA.PIN_MEMORY = True
41
+ # Number of data loading threads
42
+ _C.DATA.NUM_WORKERS = 8
43
+
44
+ # -----------------------------------------------------------------------------
45
+ # Model settings
46
+ # -----------------------------------------------------------------------------
47
+ _C.MODEL = CN()
48
+ # Model type
49
+ _C.MODEL.ARCH = 'RepVGG-L2pse'
50
+ # Checkpoint to resume, could be overwritten by command line argument
51
+ _C.MODEL.RESUME = ''
52
+ # Number of classes, overwritten in data preparation
53
+ _C.MODEL.NUM_CLASSES = 1000
54
+ # Label Smoothing
55
+ _C.MODEL.LABEL_SMOOTHING = 0.1
56
+
57
+ # -----------------------------------------------------------------------------
58
+ # Training settings
59
+ # -----------------------------------------------------------------------------
60
+ _C.TRAIN = CN()
61
+ _C.TRAIN.START_EPOCH = 0
62
+ _C.TRAIN.EPOCHS = 300
63
+ _C.TRAIN.WARMUP_EPOCHS = 20
64
+ _C.TRAIN.WEIGHT_DECAY = 0.05
65
+ _C.TRAIN.BASE_LR = 5e-4
66
+ _C.TRAIN.WARMUP_LR = 0.0
67
+ _C.TRAIN.MIN_LR = 0.0
68
+ # Clip gradient norm
69
+ _C.TRAIN.CLIP_GRAD = 0.0
70
+ # Auto resume from latest checkpoint
71
+ _C.TRAIN.AUTO_RESUME = True
72
+ # Gradient accumulation steps
73
+ # could be overwritten by command line argument
74
+ _C.TRAIN.ACCUMULATION_STEPS = 0
75
+ # Whether to use gradient checkpointing to save memory
76
+ # could be overwritten by command line argument
77
+ _C.TRAIN.USE_CHECKPOINT = False
78
+
79
+ # LR scheduler
80
+ _C.TRAIN.LR_SCHEDULER = CN()
81
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
82
+ # Epoch interval to decay LR, used in StepLRScheduler
83
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
84
+ # LR decay rate, used in StepLRScheduler
85
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
86
+
87
+ # Optimizer
88
+ _C.TRAIN.OPTIMIZER = CN()
89
+ _C.TRAIN.OPTIMIZER.NAME = 'sgd'
90
+ # Optimizer Epsilon
91
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
92
+ # Optimizer Betas
93
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
94
+ # SGD momentum
95
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
96
+
97
+ # For EMA model
98
+ _C.TRAIN.EMA_ALPHA = 0.0
99
+ _C.TRAIN.EMA_UPDATE_PERIOD = 8
100
+
101
+ # For RepOptimizer only
102
+ _C.TRAIN.SCALES_PATH = None
103
+
104
+ # -----------------------------------------------------------------------------
105
+ # Augmentation settings
106
+ # -----------------------------------------------------------------------------
107
+ _C.AUG = CN()
108
+ # Mixup alpha, mixup enabled if > 0
109
+ _C.AUG.MIXUP = 0.2
110
+ # Cutmix alpha, cutmix enabled if > 0
111
+ _C.AUG.CUTMIX = 0.0
112
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
113
+ _C.AUG.CUTMIX_MINMAX = None
114
+ # Probability of performing mixup or cutmix when either/both is enabled
115
+ _C.AUG.MIXUP_PROB = 1.0
116
+ # Probability of switching to cutmix when both mixup and cutmix enabled
117
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
118
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
119
+ _C.AUG.MIXUP_MODE = 'batch'
120
+
121
+ _C.AUG.PRESET = None # If AUG.PRESET is set (e.g., 'raug15'), use the pre-defined preprocessing and ignore the settings below.
122
+ # Color jitter factor
123
+ _C.AUG.COLOR_JITTER = 0.4
124
+ # AutoAugment/RandAugment policy string (e.g., 'v0', 'original', or a 'rand-...' spec)
125
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
126
+ # Random erase prob
127
+ _C.AUG.REPROB = 0.25
128
+ # Random erase mode
129
+ _C.AUG.REMODE = 'pixel'
130
+ # Random erase count
131
+ _C.AUG.RECOUNT = 1
132
+
133
+
134
+ # -----------------------------------------------------------------------------
135
+ # Testing settings
136
+ # -----------------------------------------------------------------------------
137
+ _C.TEST = CN()
138
+ # Whether to use center crop when testing
139
+ _C.TEST.CROP = False
140
+
141
+ # -----------------------------------------------------------------------------
142
+ # Misc
143
+ # -----------------------------------------------------------------------------
144
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
145
+ # overwritten by command line argument
146
+ _C.AMP_OPT_LEVEL = ''
147
+ # Path to output folder, overwritten by command line argument
148
+ _C.OUTPUT = ''
149
+ # Tag of experiment, overwritten by command line argument
150
+ _C.TAG = 'default'
151
+ # Frequency to save checkpoint
152
+ _C.SAVE_FREQ = 20
153
+ # Frequency to log info
154
+ _C.PRINT_FREQ = 10
155
+ # Fixed random seed
156
+ _C.SEED = 0
157
+ # Perform evaluation only, overwritten by command line argument
158
+ _C.EVAL_MODE = False
159
+ # Test throughput only, overwritten by command line argument
160
+ _C.THROUGHPUT_MODE = False
161
+ # local rank for DistributedDataParallel, given by command line argument
162
+ _C.LOCAL_RANK = 0
163
+
164
+
165
+ def update_config(config, args):
166
+ config.defrost()
167
+ if args.opts:
168
+ config.merge_from_list(args.opts)
169
+ # merge from specific arguments
170
+ if args.scales_path:
171
+ config.TRAIN.SCALES_PATH = args.scales_path
172
+ if args.arch:
173
+ config.MODEL.ARCH = args.arch
174
+ if args.batch_size:
175
+ config.DATA.BATCH_SIZE = args.batch_size
176
+ if args.data_path:
177
+ config.DATA.DATA_PATH = args.data_path
178
+ if args.zip:
179
+ config.DATA.ZIP_MODE = True
180
+ if args.cache_mode:
181
+ config.DATA.CACHE_MODE = args.cache_mode
182
+ if args.resume:
183
+ config.MODEL.RESUME = args.resume
184
+ if args.accumulation_steps:
185
+ config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
186
+ if args.use_checkpoint:
187
+ config.TRAIN.USE_CHECKPOINT = True
188
+ if args.amp_opt_level:
189
+ config.AMP_OPT_LEVEL = args.amp_opt_level
190
+ if args.output:
191
+ config.OUTPUT = args.output
192
+ if args.tag:
193
+ config.TAG = args.tag
194
+ if args.eval:
195
+ config.EVAL_MODE = True
196
+ if args.throughput:
197
+ config.THROUGHPUT_MODE = True
198
+
199
+ if config.DATA.TEST_SIZE is None:
200
+ config.DATA.TEST_SIZE = config.DATA.IMG_SIZE
201
+ if config.DATA.TEST_BATCH_SIZE is None:
202
+ config.DATA.TEST_BATCH_SIZE = config.DATA.BATCH_SIZE
203
+ # set local rank for distributed training
204
+ config.LOCAL_RANK = args.local_rank
205
+ # output folder
206
+ config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.ARCH, config.TAG)
207
+ config.freeze()
208
+
209
+
210
+ def get_config(args):
211
+ """Get a yacs CfgNode object with default values."""
212
+ # Return a clone so that the defaults will not be altered
213
+ # This is for the "local variable" use pattern
214
+ config = _C.clone()
215
+ update_config(config, args)
216
+
217
+ return config
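
A hedged sketch of driving this config from code rather than from a command line, assuming train/config.py is importable as config; every field on the namespace below mirrors an attribute that update_config reads, and the values are placeholders:

    from types import SimpleNamespace
    from config import get_config

    args = SimpleNamespace(
        opts=None, scales_path=None, arch='RepVGG-A0', batch_size=64,
        data_path='/your/path/to/dataset', zip=False, cache_mode=None, resume='',
        accumulation_steps=0, use_checkpoint=False, amp_opt_level='',
        output='./output', tag='demo', eval=False, throughput=False, local_rank=0)

    config = get_config(args)
    print(config.MODEL.ARCH, config.DATA.BATCH_SIZE, config.OUTPUT)
    # -> RepVGG-A0 64 ./output/RepVGG-A0/demo
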
RepVGG-main/train/cutout.py ADDED
@@ -0,0 +1,55 @@
1
+ import numpy as np
2
+
3
+ class Cutout:
4
+
5
+ def __init__(self, size=16) -> None:
6
+ self.size = size
7
+
8
+ def _create_cutout_mask(self, img_height, img_width, num_channels, size):
9
+ """Creates a zero mask used for cutout of shape `img_height` x `img_width`.
10
+ Args:
11
+ img_height: Height of image cutout mask will be applied to.
12
+ img_width: Width of image cutout mask will be applied to.
13
+ num_channels: Number of channels in the image.
14
+ size: Size of the zeros mask.
15
+ Returns:
16
+ A mask of shape `img_height` x `img_width` with all ones except for a
17
+ square of zeros of shape `size` x `size`. This mask is meant to be
18
+ elementwise multiplied with the original image. Additionally returns
19
+ the `upper_coord` and `lower_coord` which specify where the cutout mask
20
+ will be applied.
21
+ """
22
+ # assert img_height == img_width
23
+
24
+ # Sample center where cutout mask will be applied
25
+ height_loc = np.random.randint(low=0, high=img_height)
26
+ width_loc = np.random.randint(low=0, high=img_width)
27
+
28
+ size = int(size)
29
+ # Determine the upper-left and lower-right corners of the patch
30
+ upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
31
+ lower_coord = (
32
+ min(img_height, height_loc + size // 2),
33
+ min(img_width, width_loc + size // 2),
34
+ )
35
+ mask_height = lower_coord[0] - upper_coord[0]
36
+ mask_width = lower_coord[1] - upper_coord[1]
37
+ assert mask_height > 0
38
+ assert mask_width > 0
39
+
40
+ mask = np.ones((img_height, img_width, num_channels))
41
+ zeros = np.zeros((mask_height, mask_width, num_channels))
42
+ mask[upper_coord[0]: lower_coord[0], upper_coord[1]: lower_coord[1], :] = zeros
43
+ return mask, upper_coord, lower_coord
44
+
45
+ def __call__(self, pil_img):
46
+ pil_img = pil_img.copy()
47
+ img_height, img_width, num_channels = (*pil_img.size, 3)  # note: PIL size is (width, height); the swapped names are used consistently with pixels[i, j] = (x, y) below
48
+ _, upper_coord, lower_coord = self._create_cutout_mask(
49
+ img_height, img_width, num_channels, self.size
50
+ )
51
+ pixels = pil_img.load() # create the pixel map
52
+ for i in range(upper_coord[0], lower_coord[0]): # for every column
53
+ for j in range(upper_coord[1], lower_coord[1]): # for every row
54
+ pixels[i, j] = (125, 122, 113, 0) # fill with a fixed gray; the alpha value is ignored for RGB images
55
+ return pil_img
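
A minimal usage sketch, assuming train/cutout.py is importable as cutout; the fill color of the erased square is hard-coded in the class above:

    from PIL import Image
    from cutout import Cutout

    img = Image.new('RGB', (224, 224), color=(200, 180, 160))
    augmented = Cutout(size=16)(img)  # returns a copy with one roughly 16x16 gray square
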
RepVGG-main/train/logger.py ADDED
@@ -0,0 +1,41 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+
+ import os
+ import sys
+ import logging
+ import functools
+ from termcolor import colored
+
+
+ @functools.lru_cache()
+ def create_logger(output_dir, dist_rank=0, name=''):
+     # create logger
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.DEBUG)
+     logger.propagate = False
+
+     # create formatters (plain for files, colored for the console)
+     fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+     color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
+         colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+     # create a console handler for the master process only
+     if dist_rank == 0:
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setLevel(logging.DEBUG)
+         console_handler.setFormatter(
+             logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+         logger.addHandler(console_handler)
+
+     # create a per-rank file handler
+     file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+     file_handler.setLevel(logging.DEBUG)
+     file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+     logger.addHandler(file_handler)
+
+     return logger
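A minimal sketch of calling `create_logger`; the output directory and logger name are illustrative. Rank 0 logs to both stdout and a file, while other ranks only get their per-rank file handler:

```python
# Illustrative sketch: one logger per process in a distributed run.
import os

from train.logger import create_logger

os.makedirs('output', exist_ok=True)  # the FileHandler needs the directory to exist
logger = create_logger(output_dir='output', dist_rank=0, name='repvgg')
logger.info('written to stdout (colored) and to output/log_rank0.txt')
```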
RepVGG-main/train/lr_scheduler.py ADDED
@@ -0,0 +1,101 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+
+ import torch
+ from timm.scheduler.cosine_lr import CosineLRScheduler
+ from timm.scheduler.step_lr import StepLRScheduler
+ from timm.scheduler.scheduler import Scheduler
+
+
+ def build_scheduler(config, optimizer, n_iter_per_epoch):
+     # all schedules are defined in units of iterations, not epochs
+     num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
+     warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
+     decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
+
+     lr_scheduler = None
+     if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+         lr_scheduler = CosineLRScheduler(
+             optimizer,
+             t_initial=num_steps,
+             lr_min=config.TRAIN.MIN_LR,
+             warmup_lr_init=config.TRAIN.WARMUP_LR,
+             warmup_t=warmup_steps,
+             cycle_limit=1,
+             t_in_epochs=False,
+         )
+     elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
+         lr_scheduler = LinearLRScheduler(
+             optimizer,
+             t_initial=num_steps,
+             lr_min_rate=0.01,
+             warmup_lr_init=config.TRAIN.WARMUP_LR,
+             warmup_t=warmup_steps,
+             t_in_epochs=False,
+         )
+     elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
+         lr_scheduler = StepLRScheduler(
+             optimizer,
+             decay_t=decay_steps,
+             decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
+             warmup_lr_init=config.TRAIN.WARMUP_LR,
+             warmup_t=warmup_steps,
+             t_in_epochs=False,
+         )
+
+     return lr_scheduler
+
+
+ class LinearLRScheduler(Scheduler):
+     def __init__(self,
+                  optimizer: torch.optim.Optimizer,
+                  t_initial: int,
+                  lr_min_rate: float,
+                  warmup_t=0,
+                  warmup_lr_init=0.,
+                  t_in_epochs=True,
+                  noise_range_t=None,
+                  noise_pct=0.67,
+                  noise_std=1.0,
+                  noise_seed=42,
+                  initialize=True,
+                  ) -> None:
+         super().__init__(
+             optimizer, param_group_field="lr",
+             noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+             initialize=initialize)
+
+         self.t_initial = t_initial
+         self.lr_min_rate = lr_min_rate
+         self.warmup_t = warmup_t
+         self.warmup_lr_init = warmup_lr_init
+         self.t_in_epochs = t_in_epochs
+         if self.warmup_t:
+             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+             super().update_groups(self.warmup_lr_init)
+         else:
+             self.warmup_steps = [1 for _ in self.base_values]
+
+     def _get_lr(self, t):
+         if t < self.warmup_t:
+             # linear warmup from warmup_lr_init up to the base LR
+             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+         else:
+             # linear decay from the base LR down to base LR * lr_min_rate
+             t = t - self.warmup_t
+             total_t = self.t_initial - self.warmup_t
+             lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+         return lrs
+
+     def get_epoch_values(self, epoch: int):
+         if self.t_in_epochs:
+             return self._get_lr(epoch)
+         else:
+             return None
+
+     def get_update_values(self, num_updates: int):
+         if not self.t_in_epochs:
+             return self._get_lr(num_updates)
+         else:
+             return None
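All three schedules are built with `t_in_epochs=False`, so they are advanced once per optimizer update via timm's `step_update` rather than once per epoch; after warmup, the linear schedule decays the LR from `base_lr` to `base_lr * lr_min_rate`. A sketch of the stepping convention, assuming `config`, `optimizer`, and `loader` are set up as in the rest of the training code:

```python
# Sketch of per-iteration scheduler stepping (t_in_epochs=False).
from train.lr_scheduler import build_scheduler

lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch=len(loader))
for epoch in range(config.TRAIN.EPOCHS):
    for idx, (images, targets) in enumerate(loader):
        ...  # forward / backward / optimizer.step()
        # timm schedulers count optimizer updates, not epochs
        lr_scheduler.step_update(epoch * len(loader) + idx)
```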
RepVGG-main/train/optimizer.py ADDED
@@ -0,0 +1,71 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+
+ from torch import optim
+
+
+ def build_optimizer(config, model):
+     """
+     Build the optimizer. The weight decay of normalization layers and biases is set to 0 by default.
+     """
+     skip = {}
+     skip_keywords = {}
+     if hasattr(model, 'no_weight_decay'):
+         skip = model.no_weight_decay()
+     if hasattr(model, 'no_weight_decay_keywords'):
+         skip_keywords = model.no_weight_decay_keywords()
+     echo = (config.LOCAL_RANK == 0)
+     parameters = set_weight_decay(model, skip, skip_keywords, echo=echo)
+     opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
+     optimizer = None
+     if opt_lower == 'sgd':
+         optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
+                               lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
+         if echo:
+             print('================================== SGD nest, momentum = {}, wd = {}'.format(config.TRAIN.OPTIMIZER.MOMENTUM, config.TRAIN.WEIGHT_DECAY))
+     elif opt_lower == 'adam':
+         if echo:
+             print('adam')
+         optimizer = optim.Adam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
+                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
+     elif opt_lower == 'adamw':
+         optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
+                                 lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
+
+     return optimizer
+
+
+ def set_weight_decay(model, skip_list=(), skip_keywords=(), echo=False):
+     has_decay = []
+     no_decay = []
+
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue  # frozen weights
+         if 'identity.weight' in name:
+             # the identity-branch BN scale of a RepVGG block deliberately keeps weight decay
+             has_decay.append(param)
+             if echo:
+                 print(f"{name} USE weight decay")
+         elif len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
+                 check_keywords_in_name(name, skip_keywords):
+             no_decay.append(param)
+             if echo:
+                 print(f"{name} has no weight decay")
+         else:
+             has_decay.append(param)
+             if echo:
+                 print(f"{name} USE weight decay")
+
+     return [{'params': has_decay},
+             {'params': no_decay, 'weight_decay': 0.}]
+
+
+ def check_keywords_in_name(name, keywords=()):
+     isin = False
+     for keyword in keywords:
+         if keyword in name:
+             isin = True
+     return isin
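To see the grouping that `set_weight_decay` produces, here is a sketch on a toy module (the module itself is an illustrative assumption): 1-D parameters such as biases and BatchNorm affine weights land in the zero-decay group, while conv/linear kernels keep the configured decay.

```python
# Illustrative sketch: inspecting the two parameter groups.
import torch.nn as nn

from train.optimizer import set_weight_decay

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
groups = set_weight_decay(toy, echo=True)
assert len(groups[0]['params']) == 1    # conv kernel: decays
assert len(groups[1]['params']) == 3    # conv bias + BN weight/bias: no decay
assert groups[1]['weight_decay'] == 0.
```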
RepVGG-main/train/randaug.py ADDED
@@ -0,0 +1,407 @@
+ import math
+ import random
+
+ import numpy as np
+ import PIL
+ from PIL import Image, ImageEnhance, ImageOps
+
+ from train.cutout import Cutout
+
+
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+ _FILL = (128, 128, 128)
+
+ # This signifies the max integer that the controller RNN could predict for the
+ # augmentation scheme.
+ _MAX_LEVEL = 10.
+
+ _HPARAMS_DEFAULT = dict(
+     translate_const=250,
+     img_mean=_FILL,
+ )
+
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+ def _interpolation(kwargs):
+     interpolation = kwargs.pop('resample', Image.BILINEAR)
+     if isinstance(interpolation, (list, tuple)):
+         return random.choice(interpolation)
+     else:
+         return interpolation
+
+
+ def _check_args_tf(kwargs):
+     if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+         kwargs.pop('fillcolor')
+     kwargs['resample'] = _interpolation(kwargs)
+
+
+ def cutout(img, factor, **kwargs):
+     _check_args_tf(kwargs)
+     return Cutout(size=factor)(img)
+
+
+ def shear_x(img, factor, **kwargs):
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+ def shear_y(img, factor, **kwargs):
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+ def translate_x_rel(img, pct, **kwargs):
+     pixels = pct * img.size[0]
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+ def translate_y_rel(img, pct, **kwargs):
+     pixels = pct * img.size[1]
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+ def translate_x_abs(img, pixels, **kwargs):
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+ def translate_y_abs(img, pixels, **kwargs):
+     _check_args_tf(kwargs)
+     return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+ def rotate(img, degrees, **kwargs):
+     _check_args_tf(kwargs)
+     if _PIL_VER >= (5, 2):
+         return img.rotate(degrees, **kwargs)
+     elif _PIL_VER >= (5, 0):
+         w, h = img.size
+         post_trans = (0, 0)
+         rotn_center = (w / 2.0, h / 2.0)
+         angle = -math.radians(degrees)
+         matrix = [
+             round(math.cos(angle), 15),
+             round(math.sin(angle), 15),
+             0.0,
+             round(-math.sin(angle), 15),
+             round(math.cos(angle), 15),
+             0.0,
+         ]
+
+         def transform(x, y, matrix):
+             (a, b, c, d, e, f) = matrix
+             return a * x + b * y + c, d * x + e * y + f
+
+         matrix[2], matrix[5] = transform(
+             -rotn_center[0] - post_trans[0],
+             -rotn_center[1] - post_trans[1], matrix
+         )
+         matrix[2] += rotn_center[0]
+         matrix[5] += rotn_center[1]
+         return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+     else:
+         return img.rotate(degrees, resample=kwargs['resample'])
+
+
+ def auto_contrast(img, **__):
+     return ImageOps.autocontrast(img)
+
+
+ def invert(img, **__):
+     return ImageOps.invert(img)
+
+
+ def identity(img, **__):
+     return img
+
+
+ def equalize(img, **__):
+     return ImageOps.equalize(img)
+
+
+ def solarize(img, thresh, **__):
+     return ImageOps.solarize(img, thresh)
+
+
+ def solarize_add(img, add, thresh=128, **__):
+     lut = []
+     for i in range(256):
+         if i < thresh:
+             lut.append(min(255, i + add))
+         else:
+             lut.append(i)
+     if img.mode in ("L", "RGB"):
+         if img.mode == "RGB" and len(lut) == 256:
+             lut = lut + lut + lut
+         return img.point(lut)
+     else:
+         return img
+
+
+ def posterize(img, bits_to_keep, **__):
+     if bits_to_keep >= 8:
+         return img
+     return ImageOps.posterize(img, bits_to_keep)
+
+
+ def contrast(img, factor, **__):
+     return ImageEnhance.Contrast(img).enhance(factor)
+
+
+ def color(img, factor, **__):
+     return ImageEnhance.Color(img).enhance(factor)
+
+
+ def brightness(img, factor, **__):
+     return ImageEnhance.Brightness(img).enhance(factor)
+
+
+ def sharpness(img, factor, **__):
+     return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+ def _randomly_negate(v):
+     """With 50% probability, negate the value."""
+     return -v if random.random() > 0.5 else v
+
+
+ def _cutout_level_to_arg(level, _hparams):
+     # range [2, 40]
+     level = max(2, (level / _MAX_LEVEL) * 40.)
+     return level,
+
+
+ def _rotate_level_to_arg(level, _hparams):
+     # range [-30, 30]
+     level = (level / _MAX_LEVEL) * 30.
+     level = _randomly_negate(level)
+     return level,
+
+
+ def _enhance_level_to_arg(level, _hparams):
+     # range [0.1, 1.9]
+     return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+
+ def _shear_level_to_arg(level, _hparams):
+     # range [-0.3, 0.3]
+     level = (level / _MAX_LEVEL) * 0.3
+     level = _randomly_negate(level)
+     return level,
+
+
+ def _translate_abs_level_to_arg(level, hparams):
+     translate_const = hparams['translate_const']
+     level = (level / _MAX_LEVEL) * float(translate_const)
+     level = _randomly_negate(level)
+     return level,
+
+
+ def _translate_rel_level_to_arg(level, _hparams):
+     # range [-0.45, 0.45]
+     level = (level / _MAX_LEVEL) * 0.45
+     level = _randomly_negate(level)
+     return level,
+
+
+ def _posterize_original_level_to_arg(level, _hparams):
+     # As per original AutoAugment paper description
+     # range [4, 8], 'keep 4 up to 8 MSB of image'
+     return int((level / _MAX_LEVEL) * 4) + 4,
+
+
+ def _posterize_research_level_to_arg(level, _hparams):
+     # As per Tensorflow models research and UDA impl
+     # range [4, 0], 'keep 4 down to 0 MSB of original image'
+     return 4 - int((level / _MAX_LEVEL) * 4),
+
+
+ def _posterize_tpu_level_to_arg(level, _hparams):
+     # As per Tensorflow TPU EfficientNet impl
+     # range [0, 4], 'keep 0 up to 4 MSB of original image'
+     return int((level / _MAX_LEVEL) * 4),
+
+
+ def _solarize_level_to_arg(level, _hparams):
+     # range [0, 256]
+     return int((level / _MAX_LEVEL) * 256),
+
+
+ def _solarize_add_level_to_arg(level, _hparams):
+     # range [0, 110]
+     return int((level / _MAX_LEVEL) * 110),
+
+
+ LEVEL_TO_ARG = {
+     'AutoContrast': None,
+     'Equalize': None,
+     'Invert': None,
+     'Identity': None,
+     'Rotate': _rotate_level_to_arg,
+     'PosterizeOriginal': _posterize_original_level_to_arg,
+     'PosterizeResearch': _posterize_research_level_to_arg,
+     'PosterizeTpu': _posterize_tpu_level_to_arg,
+     'Solarize': _solarize_level_to_arg,
+     'SolarizeAdd': _solarize_add_level_to_arg,
+     'Color': _enhance_level_to_arg,
+     'Contrast': _enhance_level_to_arg,
+     'Brightness': _enhance_level_to_arg,
+     'Sharpness': _enhance_level_to_arg,
+     'ShearX': _shear_level_to_arg,
+     'ShearY': _shear_level_to_arg,
+     'TranslateX': _translate_abs_level_to_arg,
+     'TranslateY': _translate_abs_level_to_arg,
+     'TranslateXRel': _translate_rel_level_to_arg,
+     'TranslateYRel': _translate_rel_level_to_arg,
+     'Cutout': _cutout_level_to_arg,
+ }
+
+
+ NAME_TO_OP = {
+     'AutoContrast': auto_contrast,
+     'Equalize': equalize,
+     'Invert': invert,
+     'Identity': identity,
+     'Rotate': rotate,
+     'PosterizeOriginal': posterize,
+     'PosterizeResearch': posterize,
+     'PosterizeTpu': posterize,
+     'Solarize': solarize,
+     'SolarizeAdd': solarize_add,
+     'Color': color,
+     'Contrast': contrast,
+     'Brightness': brightness,
+     'Sharpness': sharpness,
+     'ShearX': shear_x,
+     'ShearY': shear_y,
+     'TranslateX': translate_x_abs,
+     'TranslateY': translate_y_abs,
+     'TranslateXRel': translate_x_rel,
+     'TranslateYRel': translate_y_rel,
+     'Cutout': cutout,
+ }
+
+
+ class AutoAugmentTransform(object):
+     """
+     AutoAugment from Google.
+     Implementation adapted from:
+     https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+     """
+
+     def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+         """
+         Args:
+             name (str): name of one of the transforms listed in _RAND_TRANSFORMS.
+             prob (float): probability of applying the augmentation.
+             magnitude (int): intensity / magnitude of the augmentation.
+             hparams (dict): hyper-parameters required by the augmentation.
+         """
+         hparams = hparams or _HPARAMS_DEFAULT
+         self.aug_fn = NAME_TO_OP[name]
+         self.level_fn = LEVEL_TO_ARG[name]
+         self.prob = prob
+         self.magnitude = magnitude
+         self.hparams = hparams.copy()
+         self.kwargs = dict(
+             fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+             resample=hparams['interpolation'] if 'interpolation' in hparams
+             else _RANDOM_INTERPOLATION,
+         )
+
+         # If magnitude_std is > 0, we introduce some randomness
+         # in the usually fixed policy and sample magnitude from a normal distribution
+         # with mean `magnitude` and std-dev of `magnitude_std`.
+         # NOTE This is my own hack, being tested, not in papers or reference impls.
+         self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+     def __call__(self, img: Image.Image) -> Image.Image:
+         if random.random() > self.prob:
+             return img
+         magnitude = self.magnitude
+         if self.magnitude_std and self.magnitude_std > 0:
+             magnitude = random.gauss(magnitude, self.magnitude_std)
+         # NOTE: the magnitude is not clipped to a valid range here
+         # magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
+         level_args = self.level_fn(
+             magnitude, self.hparams) if self.level_fn is not None else tuple()
+         return self.aug_fn(img, *level_args, **self.kwargs)
+         # return np.array(self.aug_fn(Image.fromarray(img), *level_args, **self.kwargs))
+
+     # def apply_coords(self, coords: np.ndarray) -> np.ndarray:
+     #     return coords
+
+
+ _RAND_TRANSFORMS = [
+     'AutoContrast',
+     'Equalize',
+     'Invert',
+     'Rotate',
+     'PosterizeTpu',
+     'Solarize',
+     'SolarizeAdd',
+     'Color',
+     'Contrast',
+     'Brightness',
+     'Sharpness',
+     'ShearX',
+     'ShearY',
+     'TranslateXRel',
+     'TranslateYRel',
+     'Cutout',  # FIXME I implement this as random erasing separately
+ ]
+
+ _RAND_TRANSFORMS_CMC = [
+     'AutoContrast',
+     'Identity',
+     'Rotate',
+     'Sharpness',
+     'ShearX',
+     'ShearY',
+     'TranslateXRel',
+     'TranslateYRel',
+     # 'Cutout'  # FIXME I implement this as random erasing separately
+ ]
+
+
+ # These experimental weights are based loosely on the relative improvements mentioned in the paper.
+ # They may not result in increased performance, but could likely be tuned to do so.
+ _RAND_CHOICE_WEIGHTS_0 = {
+     'Rotate': 0.3,
+     'ShearX': 0.2,
+     'ShearY': 0.2,
+     'TranslateXRel': 0.1,
+     'TranslateYRel': 0.1,
+     'Color': 0.025,
+     'Sharpness': 0.025,
+     'AutoContrast': 0.025,
+     'Solarize': 0.005,
+     'SolarizeAdd': 0.005,
+     'Contrast': 0.005,
+     'Brightness': 0.005,
+     'Equalize': 0.005,
+     'PosterizeTpu': 0,
+     'Invert': 0,
+ }
+
+
+ class RandAugPolicy(object):
+     def __init__(self, layers=2, magnitude=10):
+         self.layers = layers
+         self.magnitude = magnitude
+
+     def __call__(self, img):
+         for _ in range(self.layers):
+             trans = np.random.choice(_RAND_TRANSFORMS)
+             # NOTE: prob apply, fixed magnitude
+             # trans_op = AutoAugmentTransform(trans, prob=np.random.uniform(0.2, 0.8), magnitude=self.magnitude)
+             # NOTE: always apply, random magnitude
+             trans_op = AutoAugmentTransform(trans, prob=1.0, magnitude=np.random.choice(self.magnitude))
+             img = trans_op(img)
+             assert img is not None, trans
+         return img
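A sketch of plugging `RandAugPolicy` into a torchvision pipeline (the surrounding transforms are illustrative). Note that with an integer `magnitude`, each call draws the per-op level via `np.random.choice(self.magnitude)`, i.e. uniformly from `{0, ..., magnitude - 1}`:

```python
# Illustrative sketch: RandAugPolicy inside a standard training pipeline.
from torchvision import transforms

from train.randaug import RandAugPolicy

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    RandAugPolicy(layers=2, magnitude=10),  # 2 randomly chosen ops per image
    transforms.ToTensor(),
])
```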
RepVGG-main/utils.py ADDED
@@ -0,0 +1,249 @@
+ # --------------------------------------------------------
+ # RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
+ # Github source: https://github.com/DingXiaoH/RepVGG
+ # Licensed under The MIT License [see LICENSE for details]
+ # The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+ # --------------------------------------------------------
+
+ import math
+ import os
+
+ import torch
+ import torch.distributed as dist
+
+ try:
+     # noinspection PyUnresolvedReferences
+     from apex import amp
+ except ImportError:
+     amp = None
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value."""
+     def __init__(self, name, fmt=':f'):
+         self.name = name
+         self.fmt = fmt
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+     def __str__(self):
+         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+         return fmtstr.format(**self.__dict__)
+
+
+ class ProgressMeter(object):
+     def __init__(self, num_batches, meters, prefix=""):
+         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+         self.meters = meters
+         self.prefix = prefix
+
+     def display(self, batch):
+         entries = [self.prefix + self.batch_fmtstr.format(batch)]
+         entries += [str(meter) for meter in self.meters]
+         print('\t'.join(entries))
+
+     def _get_batch_fmtstr(self, num_batches):
+         num_digits = len(str(num_batches))
+         fmt = '{:' + str(num_digits) + 'd}'
+         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+ def accuracy(output, target, topk=(1,)):
+     """Computes the accuracy over the k top predictions for the specified values of k."""
+     with torch.no_grad():
+         maxk = max(topk)
+         batch_size = target.size(0)
+
+         _, pred = output.topk(maxk, 1, True, True)
+         pred = pred.t()
+         correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+         res = []
+         for k in topk:
+             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+             res.append(correct_k.mul_(100.0 / batch_size))
+         return res
+
+
+ # NOTE: this simple loader is shadowed by the config-based load_checkpoint defined below.
+ def load_checkpoint(model, ckpt_path):
+     checkpoint = torch.load(ckpt_path)
+     if 'model' in checkpoint:
+         checkpoint = checkpoint['model']
+     if 'state_dict' in checkpoint:
+         checkpoint = checkpoint['state_dict']
+     ckpt = {}
+     for k, v in checkpoint.items():
+         if k.startswith('module.'):
+             ckpt[k[7:]] = v  # strip the 'module.' prefix added by (Distributed)DataParallel
+         else:
+             ckpt[k] = v
+     model.load_state_dict(ckpt)
+
+
+ class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
+
+     def __init__(self, optimizer, T_cosine_max, eta_min=0, last_epoch=-1, warmup=0):
+         self.eta_min = eta_min
+         self.T_cosine_max = T_cosine_max
+         self.warmup = warmup
+         super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup:
+             return [self.last_epoch / self.warmup * base_lr for base_lr in self.base_lrs]
+         else:
+             return [self.eta_min + (base_lr - self.eta_min) *
+                     (1 + math.cos(math.pi * (self.last_epoch - self.warmup) / (self.T_cosine_max - self.warmup))) / 2
+                     for base_lr in self.base_lrs]
+
+
+ def log_msg(message, log_file):
+     print(message)
+     with open(log_file, 'a') as f:
+         print(message, file=f)
+
+
+ def unwrap_model(model):
+     """Remove the DistributedDataParallel wrapper if present."""
+     wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel)
+     return model.module if wrapped else model
+
+
+ def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None):
+     logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
+     if config.MODEL.RESUME.startswith('https'):
+         checkpoint = torch.hub.load_state_dict_from_url(
+             config.MODEL.RESUME, map_location='cpu', check_hash=True)
+     else:
+         checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
+     msg = model.load_state_dict(checkpoint['model'], strict=False)
+     logger.info(msg)
+     max_accuracy = 0.0
+     if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+         optimizer.load_state_dict(checkpoint['optimizer'])
+         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+         config.defrost()
+         config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
+         config.freeze()
+         if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
+             amp.load_state_dict(checkpoint['amp'])
+         logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
+         if 'max_accuracy' in checkpoint:
+             max_accuracy = checkpoint['max_accuracy']
+     if model_ema is not None:
+         unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
+         print('=================================================== EMA loaded')
+
+     del checkpoint
+     torch.cuda.empty_cache()
+     return max_accuracy
+
+
+ def load_weights(model, path):
+     checkpoint = torch.load(path, map_location='cpu')
+     if 'model' in checkpoint:
+         checkpoint = checkpoint['model']
+     if 'state_dict' in checkpoint:
+         checkpoint = checkpoint['state_dict']
+     unwrap_model(model).load_state_dict(checkpoint, strict=False)
+     print('=================== loaded from', path)
+
+
+ def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None):
+     save_state = {'model': model.state_dict(),
+                   'optimizer': optimizer.state_dict(),
+                   'lr_scheduler': lr_scheduler.state_dict(),
+                   'max_accuracy': max_accuracy,
+                   'epoch': epoch,
+                   'config': config}
+     if config.AMP_OPT_LEVEL != "O0":
+         save_state['amp'] = amp.state_dict()
+     if model_ema is not None:
+         save_state['ema'] = unwrap_model(model_ema).state_dict()
+
+     save_path = os.path.join(config.OUTPUT, 'latest.pth')
+     logger.info(f"{save_path} saving......")
+     torch.save(save_state, save_path)
+     logger.info(f"{save_path} saved !!!")
+
+
+ def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
+     save_state = {'model': model.state_dict(),
+                   'optimizer': optimizer.state_dict(),
+                   'lr_scheduler': lr_scheduler.state_dict(),
+                   'max_accuracy': max_accuracy,
+                   'epoch': epoch,
+                   'config': config}
+     if config.AMP_OPT_LEVEL != "O0":
+         save_state['amp'] = amp.state_dict()
+     if model_ema is not None:
+         save_state['ema'] = unwrap_model(model_ema).state_dict()
+
+     if is_best:
+         best_path = os.path.join(config.OUTPUT, 'best_ckpt.pth')
+         torch.save(save_state, best_path)
+
+     save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
+     logger.info(f"{save_path} saving......")
+     torch.save(save_state, save_path)
+     logger.info(f"{save_path} saved !!!")
+
+
+ def get_grad_norm(parameters, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+     total_norm = total_norm ** (1. / norm_type)
+     return total_norm
+
+
+ def auto_resume_helper(output_dir):
+     checkpoints = os.listdir(output_dir)
+     checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and 'ema' not in ckpt]
+     print(f"All checkpoints found in {output_dir}: {checkpoints}")
+     if len(checkpoints) > 0:
+         latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
+         print(f"The latest checkpoint found: {latest_checkpoint}")
+         resume_file = latest_checkpoint
+     else:
+         resume_file = None
+     return resume_file
+
+
+ def reduce_tensor(tensor):
+     rt = tensor.clone()
+     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+     rt /= dist.get_world_size()
+     return rt
+
+
+ def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
+     """Update the exponential moving average (EMA) of the model weights."""
+     update_period = cfg.TRAIN.EMA_UPDATE_PERIOD
+     if update_period is None or update_period == 0 or cur_iter % update_period != 0:
+         return
+     # Adjust alpha to be fairly independent of other parameters
+     total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
+     adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
+     # print('ema adjust', adjust)
+     alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)
+     # During warmup, simply copy over the weights instead of using the EMA
+     alpha = 1.0 if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS else alpha
+     # Take the EMA of all parameters (not just named parameters)
+     params = unwrap_model(model).state_dict()
+     for name, param in unwrap_model(model_ema).state_dict().items():
+         param.copy_(param * (1.0 - alpha) + params[name] * alpha)
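A sketch of the `accuracy`/`AverageMeter` pattern these helpers support (the random logits and labels are illustrative):

```python
# Illustrative sketch: batch-wise top-k accuracy tracking.
import torch

from utils import AverageMeter, accuracy

top1 = AverageMeter('Acc@1', ':6.2f')
logits = torch.randn(8, 1000)               # fake logits for a batch of 8
labels = torch.randint(0, 1000, (8,))       # fake ground-truth labels
acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
top1.update(acc1.item(), n=labels.size(0))  # accuracy() returns percentages
print(top1)                                 # e.g. "Acc@1   0.00 (  0.00)"
```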
models/RepVGG-A0-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e538bfe8639d53a8cbeb4b580aac3dad8ecc304d71eddcb169f660a24fa80bb7
+ size 36588855
models/RepVGG-A1-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11cc973c2faf93b379205260a66bf3824a5abd206f420b9aaab94c0ede3c992a
+ size 56547389
models/RepVGG-A2-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f047695db0247758805d9051b2a3acd31a82e8e995291d28b969698fe262cc86
+ size 113071461
models/RepVGG-B0-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7c2d6ea9b3dfce017bead0eb3fff4f98e37b9e6a5aa15c2d5b75d4ac3f52bb0
+ size 63489053
models/RepVGG-B1-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fef08fb7dbe34966de425a4d1936c523d67bc05fadfe5228ea39514cbc62e1f5
+ size 230009223
models/RepVGG-B1g2-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:487c633c0161f48db2bd0aa698ef0b0b294dd7ff534305915bd2172cf1030ca9
+ size 183478388
models/RepVGG-B1g4-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70017ec47e9a0f52e589992774292aef87c65883a0b0e5e2c6b3b04c43d2a9f1
+ size 160213188
models/RepVGG-B2-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c603f0af6a388dad15b26cf452b4b0ff71be5f9e3f6d8e09259a804a2e20def
+ size 356505995
models/RepVGG-B2g4-200epochs-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cff67b5ec33d5a88f7d6b677d092bec97cb625e3dbea97c316bd3871f2617464
+ size 247450330
models/RepVGG-B2g4-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:347cfa7201be39656e6638b0d725edfb0d2d828e2eeac6fadde77519f9ea68bc
+ size 247450480
models/RepVGG-B3-200epochs-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ddb6ba07d090d7369486e7c90a2ebbd706af399cf40c619812fa52f6d91d928d
+ size 492817293
models/RepVGG-B3g4-200epochs-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:98e166bd59d7ce972058537363bad91e2295caab58fbef686495d79d557aaf22
+ size 335776986
models/RepVGG-D2se-200epochs-train.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91b227e636de504a1549adac2f8f1d554c9a121e10c0287dd30e9ad90e481cab
+ size 534046654
models/RepVGGplus-L2pse-train-custom-wd-acc84.16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b6fbe09cb9968996a3c8e22e77facf1eae110bbfe5681910eafeabc5293b55c
+ size 585161602
models/RepVGGplus-L2pse-train256-acc84.06.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:367ef8381dd6afd7565027c07248096fae7b594946b5c2b122df574a849346f8
+ size 585148392