Upload 9 files
Browse files- .gitignore +112 -0
- LICENSE +21 -0
- README.md +149 -3
- config.json +57 -0
- inference.py +116 -0
- parse_config.py +163 -0
- requirements.txt +7 -0
- test.py +77 -0
- train.py +85 -0
.gitignore
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
env/
|
| 12 |
+
build/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
|
| 28 |
+
# PyInstaller
|
| 29 |
+
# Usually these files are written by a python script from a template
|
| 30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.coverage
|
| 42 |
+
.coverage.*
|
| 43 |
+
.cache
|
| 44 |
+
nosetests.xml
|
| 45 |
+
coverage.xml
|
| 46 |
+
*.cover
|
| 47 |
+
.hypothesis/
|
| 48 |
+
|
| 49 |
+
# Translations
|
| 50 |
+
*.mo
|
| 51 |
+
*.pot
|
| 52 |
+
|
| 53 |
+
# Django stuff:
|
| 54 |
+
*.log
|
| 55 |
+
local_settings.py
|
| 56 |
+
|
| 57 |
+
# Flask stuff:
|
| 58 |
+
instance/
|
| 59 |
+
.webassets-cache
|
| 60 |
+
|
| 61 |
+
# Scrapy stuff:
|
| 62 |
+
.scrapy
|
| 63 |
+
|
| 64 |
+
# Sphinx documentation
|
| 65 |
+
docs/_build/
|
| 66 |
+
|
| 67 |
+
# PyBuilder
|
| 68 |
+
target/
|
| 69 |
+
|
| 70 |
+
# Jupyter Notebook
|
| 71 |
+
.ipynb_checkpoints
|
| 72 |
+
|
| 73 |
+
# pyenv
|
| 74 |
+
.python-version
|
| 75 |
+
|
| 76 |
+
# celery beat schedule file
|
| 77 |
+
celerybeat-schedule
|
| 78 |
+
|
| 79 |
+
# SageMath parsed files
|
| 80 |
+
*.sage.py
|
| 81 |
+
|
| 82 |
+
# dotenv
|
| 83 |
+
.env
|
| 84 |
+
|
| 85 |
+
# virtualenv
|
| 86 |
+
.venv
|
| 87 |
+
venv/
|
| 88 |
+
ENV/
|
| 89 |
+
|
| 90 |
+
# Spyder project settings
|
| 91 |
+
.spyderproject
|
| 92 |
+
.spyproject
|
| 93 |
+
|
| 94 |
+
# Rope project settings
|
| 95 |
+
.ropeproject
|
| 96 |
+
|
| 97 |
+
# mkdocs documentation
|
| 98 |
+
/site
|
| 99 |
+
|
| 100 |
+
# mypy
|
| 101 |
+
.mypy_cache/
|
| 102 |
+
|
| 103 |
+
# input data, saved log, checkpoints
|
| 104 |
+
data/
|
| 105 |
+
input/
|
| 106 |
+
saved/
|
| 107 |
+
datasets/
|
| 108 |
+
|
| 109 |
+
# editor, os cache directory
|
| 110 |
+
.vscode/
|
| 111 |
+
.idea/
|
| 112 |
+
__MACOSX/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Technische Universität Kaiserslautern (TUK) & National University of Sciences and Technology (NUST)
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,149 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NUST Wheat Rust Disease (NWRD): Semantic Segmentation using Suppervised Deep Learning
|
| 2 |
+
Semantic segmentation of wheat yellow/stripe rust disease images to segment out rust and non-rust pixels using supervised deep learning.
|
| 3 |
+
|
| 4 |
+
This repo contains the source code for the study presented in the work:
|
| 5 |
+
`The NWRD Dataset: An Open-Source Annotated Segmentation Dataset of Diseased Wheat Crop`.
|
| 6 |
+
|
| 7 |
+
## Dataset
|
| 8 |
+
The NWRD dataset is a real-world segmentation dataset of wheat rust diseased and healthy leaf images specifically constructed for semantic segmentation of wheat rust disease.
|
| 9 |
+
The NWRD dataset consists of 100 images in total at this moment.
|
| 10 |
+
|
| 11 |
+
Sample images from The NWRD dataset; annotated images showing rust disease along with their binary masks:
|
| 12 |
+

|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Dataset is available at: https://dll.seecs.nust.edu.pk/downloads/
|
| 16 |
+
|
| 17 |
+
### Directory Structure
|
| 18 |
+
The NWRD dataset images are available in `.jpg` format and the annotated binary masks are available in `.png` format. Below is the directory structure of this dataset:
|
| 19 |
+
```
|
| 20 |
+
NWRD
|
| 21 |
+
├── test
|
| 22 |
+
│ ├── images
|
| 23 |
+
│ └── masks
|
| 24 |
+
└── train
|
| 25 |
+
├── images
|
| 26 |
+
└── masks
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Data Splits
|
| 30 |
+
Here are the data splits of the NWRD dataset:
|
| 31 |
+
| Split | Percentage |
|
| 32 |
+
|-----------|-----------|
|
| 33 |
+
| Train + Valid | 90 |
|
| 34 |
+
| Test | 10 |
|
| 35 |
+
| Total | 100 |
|
| 36 |
+
|
| 37 |
+
The experimentation with 22 images was done with the following set of images:
|
| 38 |
+
| Split | Images |
|
| 39 |
+
|-----------|-----------|
|
| 40 |
+
| Train + Valid | 2, 7, 14, 30, 64, 83, 84, 90, 94, 95, 100, 118, 124, 125, 132, 133, 136, 137, 138, 146 |
|
| 41 |
+
| Test | 67, 123 |
|
| 42 |
+
|
| 43 |
+
## Requirements
|
| 44 |
+
* Python
|
| 45 |
+
* Pytorch
|
| 46 |
+
* Torchvision
|
| 47 |
+
* Numpy
|
| 48 |
+
* Tqdm
|
| 49 |
+
* Tensorboard
|
| 50 |
+
* Scikit-learn
|
| 51 |
+
* Pandas
|
| 52 |
+
|
| 53 |
+
## Usage
|
| 54 |
+
Different aspects of the training can be tunned from the `config.json` file.
|
| 55 |
+
Use `python train.py -c config.json` to run code.
|
| 56 |
+
|
| 57 |
+
### Config file format
|
| 58 |
+
Config files are in `.json` format:
|
| 59 |
+
```javascript
|
| 60 |
+
{
|
| 61 |
+
"name": "WRS", // training session name
|
| 62 |
+
"n_gpu": 1, // number of GPUs to use for training.
|
| 63 |
+
|
| 64 |
+
"arch": {
|
| 65 |
+
"type": "UNet", // name of model architecture to train
|
| 66 |
+
"args": {
|
| 67 |
+
"n_channels": 3,
|
| 68 |
+
"n_classes": 2
|
| 69 |
+
} // pass arguments to the model
|
| 70 |
+
},
|
| 71 |
+
"data_loader": {
|
| 72 |
+
"type": "PatchedDataLoader", // selecting data loader
|
| 73 |
+
"args":{
|
| 74 |
+
"data_dir": "data/", // dataset path
|
| 75 |
+
"patch_size": 128, // patch size
|
| 76 |
+
"batch_size": 64, // batch size
|
| 77 |
+
"patch_stride": 32, // patch overlapping stride
|
| 78 |
+
"target_dist": 0.01, // least percentage of rust pixels in a patch
|
| 79 |
+
"shuffle": true, // shuffle training data before
|
| 80 |
+
"validation_split": 0.1, // size of validation dataset. float(portion) or int(number of samples)
|
| 81 |
+
"num_workers": 2 // number of cpu processes to be used for data loading
|
| 82 |
+
}
|
| 83 |
+
},
|
| 84 |
+
"optimizer": {
|
| 85 |
+
"type": "RMSprop",
|
| 86 |
+
"args":{
|
| 87 |
+
"lr": 1e-6, // learning rate
|
| 88 |
+
"weight_decay": 0
|
| 89 |
+
}
|
| 90 |
+
},
|
| 91 |
+
"loss": "focal_loss", // loss function
|
| 92 |
+
"metrics": [ // list of metrics to evaluate
|
| 93 |
+
"precision",
|
| 94 |
+
"recall",
|
| 95 |
+
"f1_score"
|
| 96 |
+
],
|
| 97 |
+
"lr_scheduler": {
|
| 98 |
+
"type": "ExponentialLR", // learning rate scheduler
|
| 99 |
+
"args": {
|
| 100 |
+
"gamma": 0.998
|
| 101 |
+
}
|
| 102 |
+
},
|
| 103 |
+
"trainer": {
|
| 104 |
+
"epochs": 500, // number of training epochs
|
| 105 |
+
"adaptive_step": 5, // update dataset after every adaptive_step epochs
|
| 106 |
+
|
| 107 |
+
"save_dir": "saved/",
|
| 108 |
+
"save_period": 1, // save checkpoints every save_period epochs
|
| 109 |
+
"verbosity": 2, // 0: quiet, 1: per epoch, 2: full
|
| 110 |
+
|
| 111 |
+
"monitor": "min val_loss", // mode and metric for model performance monitoring. set 'off' to disable.
|
| 112 |
+
"early_stop": 50, // early stop
|
| 113 |
+
|
| 114 |
+
"tensorboard": true // enable tensorboard visualization
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### Using config files
|
| 120 |
+
Modify the configurations in `.json` config files, then run:
|
| 121 |
+
|
| 122 |
+
```
|
| 123 |
+
python train.py --config config.json
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Resuming from checkpoints
|
| 127 |
+
You can resume from a previously saved checkpoint by:
|
| 128 |
+
|
| 129 |
+
```
|
| 130 |
+
python train.py --resume path/to/checkpoint
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
### Using Multiple GPU
|
| 134 |
+
You can enable multi-GPU training by setting `n_gpu` argument of the config file to larger number.
|
| 135 |
+
If configured to use smaller number of gpu than available, first n devices will be used by default.
|
| 136 |
+
Specify indices of available GPUs by cuda environmental variable.
|
| 137 |
+
```
|
| 138 |
+
python train.py --device 2,3 -c config.json
|
| 139 |
+
```
|
| 140 |
+
This is equivalent to
|
| 141 |
+
```
|
| 142 |
+
CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.py
|
| 143 |
+
```
|
| 144 |
+
## License
|
| 145 |
+
This project is licensed under the MIT License. See [LICENSE](LICENSE) for more details
|
| 146 |
+
|
| 147 |
+
## Acknowledgements
|
| 148 |
+
* This research project was funded by German Academic Exchange Service (DAAD).
|
| 149 |
+
* This project follows the template provided by [victoresque](https://github.com/victoresque/pytorch-template)
|
config.json
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "WRS_Adaptive_Patching",
|
| 3 |
+
"n_gpu": 1,
|
| 4 |
+
|
| 5 |
+
"arch": {
|
| 6 |
+
"type": "UNet",
|
| 7 |
+
"args": {
|
| 8 |
+
"n_channels": 3,
|
| 9 |
+
"n_classes": 2
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"data_loader": {
|
| 13 |
+
"type": "PatchedDataLoader",
|
| 14 |
+
"args":{
|
| 15 |
+
"data_dir": "/scratch/sukhan/wrs/datads-cv/",
|
| 16 |
+
"patch_size": 128,
|
| 17 |
+
"batch_size": 64,
|
| 18 |
+
"patch_stride": 32,
|
| 19 |
+
"target_dist": 0.01,
|
| 20 |
+
"shuffle": true,
|
| 21 |
+
"validation_split": 0.1,
|
| 22 |
+
"num_workers": 8
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
"optimizer": {
|
| 26 |
+
"type": "RMSprop",
|
| 27 |
+
"args":{
|
| 28 |
+
"lr": 1e-6,
|
| 29 |
+
"weight_decay": 0
|
| 30 |
+
}
|
| 31 |
+
},
|
| 32 |
+
"loss": "focal_loss",
|
| 33 |
+
"metrics": [
|
| 34 |
+
"precision",
|
| 35 |
+
"recall",
|
| 36 |
+
"f1_score"
|
| 37 |
+
],
|
| 38 |
+
"lr_scheduler": {
|
| 39 |
+
"type": "ExponentialLR",
|
| 40 |
+
"args": {
|
| 41 |
+
"gamma": 0.998
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"trainer": {
|
| 45 |
+
"epochs": 500,
|
| 46 |
+
"adaptive_step": 5,
|
| 47 |
+
|
| 48 |
+
"save_dir": "/scratch/sukhan/wrs/saved/",
|
| 49 |
+
"save_period": 1,
|
| 50 |
+
"verbosity": 2,
|
| 51 |
+
|
| 52 |
+
"monitor": "min val_loss",
|
| 53 |
+
"early_stop": 50,
|
| 54 |
+
|
| 55 |
+
"tensorboard": true
|
| 56 |
+
}
|
| 57 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import collections
|
| 3 |
+
import glob
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
import torch
|
| 10 |
+
from torchvision.transforms import transforms
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import data_loader.data_loaders as module_data
|
| 13 |
+
import model.model as module_arch
|
| 14 |
+
from parse_config import ConfigParser
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main(config):
|
| 18 |
+
logger = config.get_logger('infernece')
|
| 19 |
+
|
| 20 |
+
# setup data_loader instances
|
| 21 |
+
data_loader = getattr(module_data, config['data_loader']['type'])(
|
| 22 |
+
config['data_loader']['args']['data_dir'],
|
| 23 |
+
patch_size=config['data_loader']['args']['patch_size'],
|
| 24 |
+
batch_size=512,
|
| 25 |
+
shuffle=False,
|
| 26 |
+
validation_split=0.0,
|
| 27 |
+
training=False,
|
| 28 |
+
num_workers=2
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# build model architecture
|
| 32 |
+
model = config.init_obj('arch', module_arch)
|
| 33 |
+
logger.info(model)
|
| 34 |
+
|
| 35 |
+
# prepare model for testing
|
| 36 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 37 |
+
logger.info('Loading checkpoint: {} ...'.format(config.resume))
|
| 38 |
+
checkpoint = torch.load(config.resume, map_location=device)
|
| 39 |
+
state_dict = checkpoint['state_dict']
|
| 40 |
+
if config['n_gpu'] > 1:
|
| 41 |
+
model = torch.nn.DataParallel(model)
|
| 42 |
+
model.load_state_dict(state_dict)
|
| 43 |
+
model = model.to(device)
|
| 44 |
+
|
| 45 |
+
model.eval()
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
for i, (data, target) in enumerate(tqdm(data_loader)):
|
| 48 |
+
data, target = data.to(device), target.to(device)
|
| 49 |
+
output = model(data)
|
| 50 |
+
|
| 51 |
+
batch_size = data_loader.batch_size
|
| 52 |
+
patch_idx = torch.arange(
|
| 53 |
+
batch_size * i, batch_size * i + data.shape[0])
|
| 54 |
+
pred = torch.argmax(output, dim=1)
|
| 55 |
+
data_loader.dataset.patches.store_data(
|
| 56 |
+
patch_idx, [pred.unsqueeze(1)])
|
| 57 |
+
|
| 58 |
+
preds = [(data_loader.dataset.patches.combine(idx, data_idx=0).cpu(),
|
| 59 |
+
data_loader.dataset.data[idx])
|
| 60 |
+
for idx in range(len(data_loader.dataset.data))]
|
| 61 |
+
trsfm = transforms.ToPILImage()
|
| 62 |
+
|
| 63 |
+
out_dir = list(config.save_dir.parts)
|
| 64 |
+
out_dir[-3] = 'output'
|
| 65 |
+
out_dir = Path(*out_dir)
|
| 66 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 67 |
+
for pred, path in preds:
|
| 68 |
+
filename = Path(path).stem + '.png'
|
| 69 |
+
pred = trsfm(pred.float())
|
| 70 |
+
pred.save(out_dir / filename)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == '__main__':
|
| 74 |
+
args = argparse.ArgumentParser(description='PyTorch Template')
|
| 75 |
+
args.add_argument('-c', '--config', default=None, type=str,
|
| 76 |
+
help='config file path (default: None)')
|
| 77 |
+
args.add_argument('-r', '--resume', default=None, type=str,
|
| 78 |
+
help='path to latest checkpoint (default: None)')
|
| 79 |
+
args.add_argument('-d', '--device', default=None, type=str,
|
| 80 |
+
help='indices of GPUs to enable (default: all)')
|
| 81 |
+
args.add_argument('--data', default=None, type=str,
|
| 82 |
+
help='path to data (default: None)')
|
| 83 |
+
|
| 84 |
+
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
|
| 85 |
+
dst_data = Path('.data/', run_id)
|
| 86 |
+
|
| 87 |
+
data_dir = dst_data / 'test' / 'images'
|
| 88 |
+
masks_dir = dst_data / 'test' / 'masks'
|
| 89 |
+
|
| 90 |
+
data_dir.mkdir(parents=True, exist_ok=True)
|
| 91 |
+
masks_dir.mkdir(parents=True, exist_ok=True)
|
| 92 |
+
|
| 93 |
+
if '--data' in sys.argv:
|
| 94 |
+
src_data = Path(sys.argv[sys.argv.index('--data') + 1])
|
| 95 |
+
if src_data.is_file():
|
| 96 |
+
shutil.copy(src_data, data_dir)
|
| 97 |
+
mask = Image.new('1', Image.open(src_data).size)
|
| 98 |
+
mask.save(masks_dir / (src_data.stem + '.png'), 'PNG')
|
| 99 |
+
else:
|
| 100 |
+
for filename in glob.glob('*.jpg', root_dir=src_data):
|
| 101 |
+
file_path = src_data / filename
|
| 102 |
+
if file_path.is_file():
|
| 103 |
+
shutil.copy(file_path, data_dir)
|
| 104 |
+
mask = Image.new('1', Image.open(file_path).size)
|
| 105 |
+
mask.save(masks_dir / (file_path.stem + '.png'), 'PNG')
|
| 106 |
+
sys.argv += ['--data_dir', str(dst_data)]
|
| 107 |
+
|
| 108 |
+
# custom cli options to modify configuration from default values given in json file.
|
| 109 |
+
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
|
| 110 |
+
options = [
|
| 111 |
+
CustomArgs(['--data_dir'], type=str,
|
| 112 |
+
target='data_loader;args;data_dir')
|
| 113 |
+
]
|
| 114 |
+
config = ConfigParser.from_args(args, options)
|
| 115 |
+
main(config)
|
| 116 |
+
shutil.rmtree(dst_data)
|
parse_config.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from functools import partial, reduce
|
| 5 |
+
from operator import getitem
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from logger import setup_logging
|
| 8 |
+
from utils import read_json, write_json
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConfigParser:
|
| 12 |
+
def __init__(self, config, resume=None, modification=None, run_id=None):
|
| 13 |
+
"""
|
| 14 |
+
class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
|
| 15 |
+
and logging module.
|
| 16 |
+
:param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
|
| 17 |
+
:param resume: String, path to the checkpoint being loaded.
|
| 18 |
+
:param modification: Dict keychain:value, specifying position values to be replaced from config dict.
|
| 19 |
+
:param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
|
| 20 |
+
"""
|
| 21 |
+
# load config file and apply modification
|
| 22 |
+
self._config = _update_config(config, modification)
|
| 23 |
+
self.resume = resume
|
| 24 |
+
|
| 25 |
+
# set save_dir where trained model and log will be saved.
|
| 26 |
+
save_dir = Path(self.config['trainer']['save_dir'])
|
| 27 |
+
|
| 28 |
+
exper_name = self.config['name']
|
| 29 |
+
if run_id is None: # use timestamp as default run-id
|
| 30 |
+
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
|
| 31 |
+
self._save_dir = save_dir / 'models' / exper_name / run_id
|
| 32 |
+
self._log_dir = save_dir / 'log' / exper_name / run_id
|
| 33 |
+
|
| 34 |
+
# make directory for saving checkpoints and log.
|
| 35 |
+
exist_ok = run_id == ''
|
| 36 |
+
self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
|
| 37 |
+
self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
|
| 38 |
+
|
| 39 |
+
# save updated config file to the checkpoint dir
|
| 40 |
+
write_json(self.config, self.save_dir / 'config.json')
|
| 41 |
+
|
| 42 |
+
# configure logging module
|
| 43 |
+
setup_logging(self.log_dir)
|
| 44 |
+
self.log_levels = {
|
| 45 |
+
0: logging.WARNING,
|
| 46 |
+
1: logging.INFO,
|
| 47 |
+
2: logging.DEBUG
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
@classmethod
|
| 51 |
+
def from_args(cls, args, options=''):
|
| 52 |
+
"""
|
| 53 |
+
Initialize this class from some cli arguments. Used in train, test.
|
| 54 |
+
"""
|
| 55 |
+
for opt in options:
|
| 56 |
+
args.add_argument(*opt.flags, default=None, type=opt.type)
|
| 57 |
+
if not isinstance(args, tuple):
|
| 58 |
+
args = args.parse_args()
|
| 59 |
+
|
| 60 |
+
if args.device is not None:
|
| 61 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
| 62 |
+
if args.resume is not None:
|
| 63 |
+
resume = Path(args.resume)
|
| 64 |
+
cfg_fname = resume.parent / 'config.json'
|
| 65 |
+
else:
|
| 66 |
+
msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
|
| 67 |
+
assert args.config is not None, msg_no_cfg
|
| 68 |
+
resume = None
|
| 69 |
+
cfg_fname = Path(args.config)
|
| 70 |
+
|
| 71 |
+
config = read_json(cfg_fname)
|
| 72 |
+
if args.config and resume:
|
| 73 |
+
# update new config for fine-tuning
|
| 74 |
+
config.update(read_json(args.config))
|
| 75 |
+
|
| 76 |
+
# parse custom cli options into dictionary
|
| 77 |
+
modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options}
|
| 78 |
+
return cls(config, resume, modification)
|
| 79 |
+
|
| 80 |
+
def init_obj(self, name, module, *args, **kwargs):
|
| 81 |
+
"""
|
| 82 |
+
Finds a function handle with the name given as 'type' in config, and returns the
|
| 83 |
+
instance initialized with corresponding arguments given.
|
| 84 |
+
|
| 85 |
+
`object = config.init_obj('name', module, a, b=1)`
|
| 86 |
+
is equivalent to
|
| 87 |
+
`object = module.name(a, b=1)`
|
| 88 |
+
"""
|
| 89 |
+
module_name = self[name]['type']
|
| 90 |
+
module_args = dict(self[name]['args'])
|
| 91 |
+
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
|
| 92 |
+
module_args.update(kwargs)
|
| 93 |
+
return getattr(module, module_name)(*args, **module_args)
|
| 94 |
+
|
| 95 |
+
def init_ftn(self, name, module, *args, **kwargs):
|
| 96 |
+
"""
|
| 97 |
+
Finds a function handle with the name given as 'type' in config, and returns the
|
| 98 |
+
function with given arguments fixed with functools.partial.
|
| 99 |
+
|
| 100 |
+
`function = config.init_ftn('name', module, a, b=1)`
|
| 101 |
+
is equivalent to
|
| 102 |
+
`function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
|
| 103 |
+
"""
|
| 104 |
+
module_name = self[name]['type']
|
| 105 |
+
module_args = dict(self[name]['args'])
|
| 106 |
+
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
|
| 107 |
+
module_args.update(kwargs)
|
| 108 |
+
return partial(getattr(module, module_name), *args, **module_args)
|
| 109 |
+
|
| 110 |
+
def __getitem__(self, name):
|
| 111 |
+
"""Access items like ordinary dict."""
|
| 112 |
+
return self.config[name]
|
| 113 |
+
|
| 114 |
+
def get_logger(self, name, verbosity=2):
|
| 115 |
+
msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(
|
| 116 |
+
verbosity, self.log_levels.keys())
|
| 117 |
+
assert verbosity in self.log_levels, msg_verbosity
|
| 118 |
+
logger = logging.getLogger(name)
|
| 119 |
+
logger.setLevel(self.log_levels[verbosity])
|
| 120 |
+
return logger
|
| 121 |
+
|
| 122 |
+
# setting read-only attributes
|
| 123 |
+
@property
|
| 124 |
+
def config(self):
|
| 125 |
+
return self._config
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def save_dir(self):
|
| 129 |
+
return self._save_dir
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def log_dir(self):
|
| 133 |
+
return self._log_dir
|
| 134 |
+
|
| 135 |
+
# helper functions to update config dict with custom cli options
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _update_config(config, modification):
|
| 139 |
+
if modification is None:
|
| 140 |
+
return config
|
| 141 |
+
|
| 142 |
+
for k, v in modification.items():
|
| 143 |
+
if v is not None:
|
| 144 |
+
_set_by_path(config, k, v)
|
| 145 |
+
return config
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _get_opt_name(flags):
|
| 149 |
+
for flg in flags:
|
| 150 |
+
if flg.startswith('--'):
|
| 151 |
+
return flg.replace('--', '')
|
| 152 |
+
return flags[0].replace('--', '')
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _set_by_path(tree, keys, value):
|
| 156 |
+
"""Set a value in a nested object in tree by sequence of keys."""
|
| 157 |
+
keys = keys.split(';')
|
| 158 |
+
_get_by_path(tree, keys[:-1])[keys[-1]] = value
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _get_by_path(tree, keys):
|
| 162 |
+
"""Access a nested object in tree by sequence of keys."""
|
| 163 |
+
return reduce(getitem, keys, tree)
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.12
|
| 2 |
+
torchvision
|
| 3 |
+
numpy
|
| 4 |
+
tqdm
|
| 5 |
+
tensorboard>=2.9.1
|
| 6 |
+
scikit-learn
|
| 7 |
+
pandas
|
test.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import data_loader.data_loaders as module_data
|
| 5 |
+
import model.loss as module_loss
|
| 6 |
+
import model.metric as module_metric
|
| 7 |
+
import model.model as module_arch
|
| 8 |
+
from parse_config import ConfigParser
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main(config):
|
| 12 |
+
logger = config.get_logger('test')
|
| 13 |
+
|
| 14 |
+
# setup data_loader instances
|
| 15 |
+
data_loader = getattr(module_data, config['data_loader']['type'])(
|
| 16 |
+
config['data_loader']['args']['data_dir'],
|
| 17 |
+
patch_size=config['data_loader']['args']['patch_size'],
|
| 18 |
+
batch_size=512,
|
| 19 |
+
shuffle=False,
|
| 20 |
+
validation_split=0.0,
|
| 21 |
+
training=False,
|
| 22 |
+
num_workers=2
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# build model architecture
|
| 26 |
+
model = config.init_obj('arch', module_arch)
|
| 27 |
+
logger.info(model)
|
| 28 |
+
|
| 29 |
+
# prepare model for testing
|
| 30 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 31 |
+
logger.info('Loading checkpoint: {} ...'.format(config.resume))
|
| 32 |
+
checkpoint = torch.load(config.resume, map_location=device)
|
| 33 |
+
state_dict = checkpoint['state_dict']
|
| 34 |
+
if config['n_gpu'] > 1:
|
| 35 |
+
model = torch.nn.DataParallel(model)
|
| 36 |
+
model.load_state_dict(state_dict)
|
| 37 |
+
model = model.to(device)
|
| 38 |
+
|
| 39 |
+
# get function handles of loss and metrics
|
| 40 |
+
loss_fn = getattr(module_loss, config['loss'])
|
| 41 |
+
metric_fns = [getattr(module_metric, met) for met in config['metrics']]
|
| 42 |
+
|
| 43 |
+
total_loss = 0.0
|
| 44 |
+
total_metrics = torch.zeros(len(metric_fns))
|
| 45 |
+
|
| 46 |
+
model.eval()
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
for i, (data, target) in enumerate(tqdm(data_loader)):
|
| 49 |
+
data, target = data.to(device), target.to(device)
|
| 50 |
+
output = model(data)
|
| 51 |
+
|
| 52 |
+
# computing loss, metrics on test set
|
| 53 |
+
loss = loss_fn(output, target)
|
| 54 |
+
batch_size = data.shape[0]
|
| 55 |
+
total_loss += loss.item() * batch_size
|
| 56 |
+
for i, metric in enumerate(metric_fns):
|
| 57 |
+
total_metrics[i] += metric(output, target) * batch_size
|
| 58 |
+
|
| 59 |
+
n_samples = len(data_loader.sampler)
|
| 60 |
+
log = {'loss': total_loss / n_samples}
|
| 61 |
+
log.update({
|
| 62 |
+
met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
|
| 63 |
+
})
|
| 64 |
+
logger.info(log)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == '__main__':
|
| 68 |
+
args = argparse.ArgumentParser(description='PyTorch Template')
|
| 69 |
+
args.add_argument('-c', '--config', default=None, type=str,
|
| 70 |
+
help='config file path (default: None)')
|
| 71 |
+
args.add_argument('-r', '--resume', default=None, type=str,
|
| 72 |
+
help='path to latest checkpoint (default: None)')
|
| 73 |
+
args.add_argument('-d', '--device', default=None, type=str,
|
| 74 |
+
help='indices of GPUs to enable (default: all)')
|
| 75 |
+
|
| 76 |
+
config = ConfigParser.from_args(args)
|
| 77 |
+
main(config)
|
train.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import collections
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import data_loader.data_loaders as module_data
|
| 6 |
+
import model.loss as module_loss
|
| 7 |
+
import model.metric as module_metric
|
| 8 |
+
import model.model as module_arch
|
| 9 |
+
from parse_config import ConfigParser
|
| 10 |
+
from trainer import Trainer
|
| 11 |
+
from utils import prepare_device
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# fix random seeds for reproducibility
|
| 15 |
+
SEED = 123
|
| 16 |
+
torch.manual_seed(SEED)
|
| 17 |
+
torch.backends.cudnn.deterministic = True
|
| 18 |
+
torch.backends.cudnn.benchmark = False
|
| 19 |
+
np.random.seed(SEED)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main(config):
|
| 23 |
+
logger = config.get_logger('train')
|
| 24 |
+
|
| 25 |
+
# setup data_loader instances
|
| 26 |
+
data_loader = config.init_obj('data_loader', module_data)
|
| 27 |
+
valid_data_loader = data_loader.split_validation()
|
| 28 |
+
|
| 29 |
+
# setup data_loader for inference
|
| 30 |
+
data_loader.inference = getattr(module_data, config['data_loader']['type'])(
|
| 31 |
+
config['data_loader']['args']['data_dir'],
|
| 32 |
+
patch_size=config['data_loader']['args']['patch_size'],
|
| 33 |
+
batch_size=config['data_loader']['args']['batch_size'],
|
| 34 |
+
shuffle=False,
|
| 35 |
+
validation_split=0.0,
|
| 36 |
+
training=True,
|
| 37 |
+
num_workers=config['data_loader']['args']['num_workers']
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# build model architecture, then print to console
|
| 41 |
+
model = config.init_obj('arch', module_arch)
|
| 42 |
+
logger.info(model)
|
| 43 |
+
|
| 44 |
+
# prepare for (multi-device) GPU training
|
| 45 |
+
device, device_ids = prepare_device(config['n_gpu'])
|
| 46 |
+
model = model.to(device)
|
| 47 |
+
if len(device_ids) > 1:
|
| 48 |
+
model = torch.nn.DataParallel(model, device_ids=device_ids)
|
| 49 |
+
|
| 50 |
+
# get function handles of loss and metrics
|
| 51 |
+
criterion = getattr(module_loss, config['loss'])
|
| 52 |
+
metrics = [getattr(module_metric, met) for met in config['metrics']]
|
| 53 |
+
|
| 54 |
+
# build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
|
| 55 |
+
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
|
| 56 |
+
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
|
| 57 |
+
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
|
| 58 |
+
|
| 59 |
+
trainer = Trainer(model, criterion, metrics, optimizer,
|
| 60 |
+
config=config,
|
| 61 |
+
device=device,
|
| 62 |
+
data_loader=data_loader,
|
| 63 |
+
valid_data_loader=valid_data_loader,
|
| 64 |
+
lr_scheduler=lr_scheduler)
|
| 65 |
+
|
| 66 |
+
trainer.train()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
args = argparse.ArgumentParser(description='PyTorch Template')
|
| 71 |
+
args.add_argument('-c', '--config', default=None, type=str,
|
| 72 |
+
help='config file path (default: None)')
|
| 73 |
+
args.add_argument('-r', '--resume', default=None, type=str,
|
| 74 |
+
help='path to latest checkpoint (default: None)')
|
| 75 |
+
args.add_argument('-d', '--device', default=None, type=str,
|
| 76 |
+
help='indices of GPUs to enable (default: all)')
|
| 77 |
+
|
| 78 |
+
# custom cli options to modify configuration from default values given in json file.
|
| 79 |
+
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
|
| 80 |
+
options = [
|
| 81 |
+
CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
|
| 82 |
+
CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
|
| 83 |
+
]
|
| 84 |
+
config = ConfigParser.from_args(args, options)
|
| 85 |
+
main(config)
|