HirraA committed
Commit c44d232 · verified · 1 Parent(s): 7603ac0

Upload 9 files

Files changed (9)
  .gitignore +112 -0
  LICENSE +21 -0
  README.md +149 -3
  config.json +57 -0
  inference.py +116 -0
  parse_config.py +163 -0
  requirements.txt +7 -0
  test.py +77 -0
  train.py +85 -0
.gitignore ADDED
@@ -0,0 +1,112 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # dotenv
+ .env
+
+ # virtualenv
+ .venv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ # input data, saved log, checkpoints
+ data/
+ input/
+ saved/
+ datasets/
+
+ # editor, os cache directory
+ .vscode/
+ .idea/
+ __MACOSX/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Technische Universität Kaiserslautern (TUK) & National University of Sciences and Technology (NUST)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,149 @@
- ---
- license: mit
- ---
+ # NUST Wheat Rust Disease (NWRD): Semantic Segmentation using Supervised Deep Learning
+ Semantic segmentation of wheat yellow/stripe rust disease images to segment out rust and non-rust pixels using supervised deep learning.
+
+ This repo contains the source code for the study presented in the work:
+ `The NWRD Dataset: An Open-Source Annotated Segmentation Dataset of Diseased Wheat Crop`.
+
+ ## Dataset
+ The NWRD dataset is a real-world segmentation dataset of wheat rust diseased and healthy leaf images, constructed specifically for semantic segmentation of wheat rust disease.
+ The NWRD dataset currently consists of 100 images in total.
+
+ Sample images from the NWRD dataset; annotated images showing rust disease along with their binary masks:
+ ![Sample images from the NWRD dataset; annotated images showing rust disease along with their binary masks](https://github.com/saadulkh/nwrd/assets/38633812/c8677336-82a0-4637-a3f8-61f5cedbad37)
+
+ The dataset is available at: https://dll.seecs.nust.edu.pk/downloads/
+
+ ### Directory Structure
+ The NWRD dataset images are available in `.jpg` format and the annotated binary masks in `.png` format. Below is the directory structure of the dataset:
+ ```
+ NWRD
+ ├── test
+ │   ├── images
+ │   └── masks
+ └── train
+     ├── images
+     └── masks
+ ```
+
+ ### Data Splits
+ The data splits of the NWRD dataset:
+ | Split | Percentage |
+ |-----------|-----------|
+ | Train + Valid | 90 |
+ | Test | 10 |
+ | Total | 100 |
+
+ The experimentation with 22 images used the following set of images:
+ | Split | Images |
+ |-----------|-----------|
+ | Train + Valid | 2, 7, 14, 30, 64, 83, 84, 90, 94, 95, 100, 118, 124, 125, 132, 133, 136, 137, 138, 146 |
+ | Test | 67, 123 |
+
+ ## Requirements
+ * Python
+ * PyTorch
+ * Torchvision
+ * NumPy
+ * tqdm
+ * TensorBoard
+ * scikit-learn
+ * pandas
+
+ ## Usage
+ Different aspects of training can be tuned from the `config.json` file.
+ Use `python train.py -c config.json` to run the code.
+
+ ### Config file format
+ Config files are in `.json` format:
+ ```javascript
+ {
+     "name": "WRS",                       // training session name
+     "n_gpu": 1,                          // number of GPUs to use for training
+
+     "arch": {
+         "type": "UNet",                  // name of model architecture to train
+         "args": {
+             "n_channels": 3,
+             "n_classes": 2
+         }                                // arguments passed to the model
+     },
+     "data_loader": {
+         "type": "PatchedDataLoader",     // selected data loader
+         "args": {
+             "data_dir": "data/",         // dataset path
+             "patch_size": 128,           // patch size
+             "batch_size": 64,            // batch size
+             "patch_stride": 32,          // patch overlap stride
+             "target_dist": 0.01,         // minimum fraction of rust pixels in a patch
+             "shuffle": true,             // shuffle training data before splitting
+             "validation_split": 0.1,     // size of validation dataset: float (portion) or int (number of samples)
+             "num_workers": 2             // number of CPU processes used for data loading
+         }
+     },
+     "optimizer": {
+         "type": "RMSprop",
+         "args": {
+             "lr": 1e-6,                  // learning rate
+             "weight_decay": 0
+         }
+     },
+     "loss": "focal_loss",                // loss function
+     "metrics": [                         // list of metrics to evaluate
+         "precision",
+         "recall",
+         "f1_score"
+     ],
+     "lr_scheduler": {
+         "type": "ExponentialLR",         // learning rate scheduler
+         "args": {
+             "gamma": 0.998
+         }
+     },
+     "trainer": {
+         "epochs": 500,                   // number of training epochs
+         "adaptive_step": 5,              // update the dataset every adaptive_step epochs
+
+         "save_dir": "saved/",
+         "save_period": 1,                // save checkpoints every save_period epochs
+         "verbosity": 2,                  // 0: quiet, 1: per epoch, 2: full
+
+         "monitor": "min val_loss",       // mode and metric for model performance monitoring; set "off" to disable
+         "early_stop": 50,                // number of epochs to wait before early stopping
+
+         "tensorboard": true              // enable TensorBoard visualization
+     }
+ }
+ ```
+
+ ### Using config files
+ Modify the configurations in `.json` config files, then run:
+
+ ```
+ python train.py --config config.json
+ ```
+
+ ### Resuming from checkpoints
+ You can resume from a previously saved checkpoint with:
+
+ ```
+ python train.py --resume path/to/checkpoint
+ ```
+
+ ### Using Multiple GPUs
+ You can enable multi-GPU training by setting the `n_gpu` argument in the config file to a larger number.
+ If configured to use fewer GPUs than are available, the first n devices will be used by default.
+ Specify the indices of available GPUs with the `CUDA_VISIBLE_DEVICES` environment variable.
+ ```
+ python train.py --device 2,3 -c config.json
+ ```
+ This is equivalent to
+ ```
+ CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
+ ```
+
+ ## License
+ This project is licensed under the MIT License. See [LICENSE](LICENSE) for more details.
+
+ ## Acknowledgements
+ * This research project was funded by the German Academic Exchange Service (DAAD).
+ * This project follows the template provided by [victoresque](https://github.com/victoresque/pytorch-template).
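
The `patch_size`, `patch_stride`, and `target_dist` options above describe overlapping patch extraction that keeps only patches containing a minimum fraction of rust pixels. A minimal standalone sketch of that idea (function and names are hypothetical, not the repository's actual `PatchedDataLoader`):

```python
def extract_patches(mask, patch_size=128, stride=32, target_dist=0.01):
    """Slide a window over a 2-D 0/1 mask and keep windows whose fraction of
    rust (positive) pixels is at least target_dist.

    Hypothetical sketch of the patching idea; returns the top-left corner of
    each kept patch rather than tensors, unlike the real data loader.
    """
    h, w = len(mask), len(mask[0])
    kept = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            window = [row[x:x + patch_size] for row in mask[y:y + patch_size]]
            rust_fraction = sum(map(sum, window)) / (patch_size * patch_size)
            if rust_fraction >= target_dist:  # enough rust pixels to keep
                kept.append((y, x))
    return kept
```

With `stride < patch_size`, neighboring windows overlap, so sparse rust regions are seen from several offsets while mostly-healthy patches are filtered out.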
config.json ADDED
@@ -0,0 +1,57 @@
+ {
+     "name": "WRS_Adaptive_Patching",
+     "n_gpu": 1,
+
+     "arch": {
+         "type": "UNet",
+         "args": {
+             "n_channels": 3,
+             "n_classes": 2
+         }
+     },
+     "data_loader": {
+         "type": "PatchedDataLoader",
+         "args": {
+             "data_dir": "/scratch/sukhan/wrs/datads-cv/",
+             "patch_size": 128,
+             "batch_size": 64,
+             "patch_stride": 32,
+             "target_dist": 0.01,
+             "shuffle": true,
+             "validation_split": 0.1,
+             "num_workers": 8
+         }
+     },
+     "optimizer": {
+         "type": "RMSprop",
+         "args": {
+             "lr": 1e-6,
+             "weight_decay": 0
+         }
+     },
+     "loss": "focal_loss",
+     "metrics": [
+         "precision",
+         "recall",
+         "f1_score"
+     ],
+     "lr_scheduler": {
+         "type": "ExponentialLR",
+         "args": {
+             "gamma": 0.998
+         }
+     },
+     "trainer": {
+         "epochs": 500,
+         "adaptive_step": 5,
+
+         "save_dir": "/scratch/sukhan/wrs/saved/",
+         "save_period": 1,
+         "verbosity": 2,
+
+         "monitor": "min val_loss",
+         "early_stop": 50,
+
+         "tensorboard": true
+     }
+ }
inference.py ADDED
@@ -0,0 +1,116 @@
+ import argparse
+ import collections
+ import glob
+ import shutil
+ import sys
+ from datetime import datetime
+ from pathlib import Path
+
+ import PIL.Image as Image
+ import torch
+ from torchvision.transforms import transforms
+ from tqdm import tqdm
+
+ import data_loader.data_loaders as module_data
+ import model.model as module_arch
+ from parse_config import ConfigParser
+
+
+ def main(config):
+     logger = config.get_logger('inference')
+
+     # setup data_loader instances
+     data_loader = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         patch_size=config['data_loader']['args']['patch_size'],
+         batch_size=512,
+         shuffle=False,
+         validation_split=0.0,
+         training=False,
+         num_workers=2
+     )
+
+     # build model architecture
+     model = config.init_obj('arch', module_arch)
+     logger.info(model)
+
+     # prepare model for testing
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     logger.info('Loading checkpoint: {} ...'.format(config.resume))
+     checkpoint = torch.load(config.resume, map_location=device)
+     state_dict = checkpoint['state_dict']
+     if config['n_gpu'] > 1:
+         model = torch.nn.DataParallel(model)
+     model.load_state_dict(state_dict)
+     model = model.to(device)
+
+     model.eval()
+     with torch.no_grad():
+         for i, (data, target) in enumerate(tqdm(data_loader)):
+             data, target = data.to(device), target.to(device)
+             output = model(data)
+
+             batch_size = data_loader.batch_size
+             patch_idx = torch.arange(
+                 batch_size * i, batch_size * i + data.shape[0])
+             pred = torch.argmax(output, dim=1)
+             data_loader.dataset.patches.store_data(
+                 patch_idx, [pred.unsqueeze(1)])
+
+     # stitch the stored patch predictions back into full-size masks
+     preds = [(data_loader.dataset.patches.combine(idx, data_idx=0).cpu(),
+               data_loader.dataset.data[idx])
+              for idx in range(len(data_loader.dataset.data))]
+     trsfm = transforms.ToPILImage()
+
+     out_dir = list(config.save_dir.parts)
+     out_dir[-3] = 'output'
+     out_dir = Path(*out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+     for pred, path in preds:
+         filename = Path(path).stem + '.png'
+         pred = trsfm(pred.float())
+         pred.save(out_dir / filename)
+
+
+ if __name__ == '__main__':
+     args = argparse.ArgumentParser(description='PyTorch Template')
+     args.add_argument('-c', '--config', default=None, type=str,
+                       help='config file path (default: None)')
+     args.add_argument('-r', '--resume', default=None, type=str,
+                       help='path to latest checkpoint (default: None)')
+     args.add_argument('-d', '--device', default=None, type=str,
+                       help='indices of GPUs to enable (default: all)')
+     args.add_argument('--data', default=None, type=str,
+                       help='path to data (default: None)')
+
+     run_id = datetime.now().strftime(r'%m%d_%H%M%S')
+     dst_data = Path('.data/', run_id)
+
+     data_dir = dst_data / 'test' / 'images'
+     masks_dir = dst_data / 'test' / 'masks'
+
+     data_dir.mkdir(parents=True, exist_ok=True)
+     masks_dir.mkdir(parents=True, exist_ok=True)
+
+     if '--data' in sys.argv:
+         src_data = Path(sys.argv[sys.argv.index('--data') + 1])
+         if src_data.is_file():
+             shutil.copy(src_data, data_dir)
+             mask = Image.new('1', Image.open(src_data).size)
+             mask.save(masks_dir / (src_data.stem + '.png'), 'PNG')
+         else:
+             # glob's root_dir argument requires Python 3.10+
+             for filename in glob.glob('*.jpg', root_dir=src_data):
+                 file_path = src_data / filename
+                 if file_path.is_file():
+                     shutil.copy(file_path, data_dir)
+                     mask = Image.new('1', Image.open(file_path).size)
+                     mask.save(masks_dir / (file_path.stem + '.png'), 'PNG')
+         sys.argv += ['--data_dir', str(dst_data)]
+
+     # custom cli options to modify configuration from default values given in json file.
+     CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+     options = [
+         CustomArgs(['--data_dir'], type=str,
+                    target='data_loader;args;data_dir')
+     ]
+     config = ConfigParser.from_args(args, options)
+     main(config)
+     shutil.rmtree(dst_data)
parse_config.py ADDED
@@ -0,0 +1,163 @@
+ import logging
+ import os
+ from datetime import datetime
+ from functools import partial, reduce
+ from operator import getitem
+ from pathlib import Path
+
+ from logger import setup_logging
+ from utils import read_json, write_json
+
+
+ class ConfigParser:
+     def __init__(self, config, resume=None, modification=None, run_id=None):
+         """
+         Class to parse the configuration json file. Handles hyperparameters for training,
+         initialization of modules, checkpoint saving, and logging.
+         :param config: Dict containing configurations and hyperparameters for training, e.g. the contents of `config.json`.
+         :param resume: String, path to the checkpoint being loaded.
+         :param modification: Dict keychain:value, specifying positions in the config dict whose values are to be replaced.
+         :param run_id: Unique identifier for training processes, used to save checkpoints and the training log. Defaults to a timestamp.
+         """
+         # load config file and apply modification
+         self._config = _update_config(config, modification)
+         self.resume = resume
+
+         # set save_dir where the trained model and log will be saved.
+         save_dir = Path(self.config['trainer']['save_dir'])
+
+         exper_name = self.config['name']
+         if run_id is None:  # use timestamp as default run-id
+             run_id = datetime.now().strftime(r'%m%d_%H%M%S')
+         self._save_dir = save_dir / 'models' / exper_name / run_id
+         self._log_dir = save_dir / 'log' / exper_name / run_id
+
+         # make directory for saving checkpoints and log.
+         exist_ok = run_id == ''
+         self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
+         self.log_dir.mkdir(parents=True, exist_ok=exist_ok)
+
+         # save updated config file to the checkpoint dir
+         write_json(self.config, self.save_dir / 'config.json')
+
+         # configure logging module
+         setup_logging(self.log_dir)
+         self.log_levels = {
+             0: logging.WARNING,
+             1: logging.INFO,
+             2: logging.DEBUG
+         }
+
+     @classmethod
+     def from_args(cls, args, options=''):
+         """
+         Initialize this class from some CLI arguments. Used in train, test.
+         """
+         for opt in options:
+             args.add_argument(*opt.flags, default=None, type=opt.type)
+         if not isinstance(args, tuple):
+             args = args.parse_args()
+
+         if args.device is not None:
+             os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+         if args.resume is not None:
+             resume = Path(args.resume)
+             cfg_fname = resume.parent / 'config.json'
+         else:
+             msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
+             assert args.config is not None, msg_no_cfg
+             resume = None
+             cfg_fname = Path(args.config)
+
+         config = read_json(cfg_fname)
+         if args.config and resume:
+             # update new config for fine-tuning
+             config.update(read_json(args.config))
+
+         # parse custom cli options into dictionary
+         modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options}
+         return cls(config, resume, modification)
+
+     def init_obj(self, name, module, *args, **kwargs):
+         """
+         Finds a function handle with the name given as 'type' in config, and returns the
+         instance initialized with corresponding arguments given.
+
+         `object = config.init_obj('name', module, a, b=1)`
+         is equivalent to
+         `object = module.name(a, b=1)`
+         """
+         module_name = self[name]['type']
+         module_args = dict(self[name]['args'])
+         assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
+         module_args.update(kwargs)
+         return getattr(module, module_name)(*args, **module_args)
+
+     def init_ftn(self, name, module, *args, **kwargs):
+         """
+         Finds a function handle with the name given as 'type' in config, and returns the
+         function with given arguments fixed with functools.partial.
+
+         `function = config.init_ftn('name', module, a, b=1)`
+         is equivalent to
+         `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.
+         """
+         module_name = self[name]['type']
+         module_args = dict(self[name]['args'])
+         assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
+         module_args.update(kwargs)
+         return partial(getattr(module, module_name), *args, **module_args)
+
+     def __getitem__(self, name):
+         """Access items like an ordinary dict."""
+         return self.config[name]
+
+     def get_logger(self, name, verbosity=2):
+         msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(
+             verbosity, self.log_levels.keys())
+         assert verbosity in self.log_levels, msg_verbosity
+         logger = logging.getLogger(name)
+         logger.setLevel(self.log_levels[verbosity])
+         return logger
+
+     # setting read-only attributes
+     @property
+     def config(self):
+         return self._config
+
+     @property
+     def save_dir(self):
+         return self._save_dir
+
+     @property
+     def log_dir(self):
+         return self._log_dir
+
+
+ # helper functions to update config dict with custom cli options
+ def _update_config(config, modification):
+     if modification is None:
+         return config
+
+     for k, v in modification.items():
+         if v is not None:
+             _set_by_path(config, k, v)
+     return config
+
+
+ def _get_opt_name(flags):
+     for flg in flags:
+         if flg.startswith('--'):
+             return flg.replace('--', '')
+     return flags[0].replace('--', '')
+
+
+ def _set_by_path(tree, keys, value):
+     """Set a value in a nested object in tree by a sequence of keys."""
+     keys = keys.split(';')
+     _get_by_path(tree, keys[:-1])[keys[-1]] = value
+
+
+ def _get_by_path(tree, keys):
+     """Access a nested object in tree by a sequence of keys."""
+     return reduce(getitem, keys, tree)
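
The `;`-separated `target` strings of the custom CLI options (e.g. `optimizer;args;lr`) address nested config keys through `_get_by_path`/`_set_by_path`. A self-contained illustration of the same mechanism, with the helper inlined:

```python
from functools import reduce
from operator import getitem

def set_by_path(tree, keys, value):
    # 'a;b;c' addresses tree['a']['b']['c'], mirroring _set_by_path in parse_config.py
    keys = keys.split(';')
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value

# sample nested config; what passing `--lr 3e-4` on the command line would change
config = {'optimizer': {'args': {'lr': 1e-6, 'weight_decay': 0}}}
set_by_path(config, 'optimizer;args;lr', 3e-4)
```

Only the addressed leaf is replaced; sibling keys such as `weight_decay` are left untouched.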
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch>=1.12
+ torchvision
+ numpy
+ tqdm
+ tensorboard>=2.9.1
+ scikit-learn
+ pandas
test.py ADDED
@@ -0,0 +1,77 @@
+ import argparse
+
+ import torch
+ from tqdm import tqdm
+
+ import data_loader.data_loaders as module_data
+ import model.loss as module_loss
+ import model.metric as module_metric
+ import model.model as module_arch
+ from parse_config import ConfigParser
+
+
+ def main(config):
+     logger = config.get_logger('test')
+
+     # setup data_loader instances
+     data_loader = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         patch_size=config['data_loader']['args']['patch_size'],
+         batch_size=512,
+         shuffle=False,
+         validation_split=0.0,
+         training=False,
+         num_workers=2
+     )
+
+     # build model architecture
+     model = config.init_obj('arch', module_arch)
+     logger.info(model)
+
+     # prepare model for testing
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     logger.info('Loading checkpoint: {} ...'.format(config.resume))
+     checkpoint = torch.load(config.resume, map_location=device)
+     state_dict = checkpoint['state_dict']
+     if config['n_gpu'] > 1:
+         model = torch.nn.DataParallel(model)
+     model.load_state_dict(state_dict)
+     model = model.to(device)
+
+     # get function handles of loss and metrics
+     loss_fn = getattr(module_loss, config['loss'])
+     metric_fns = [getattr(module_metric, met) for met in config['metrics']]
+
+     total_loss = 0.0
+     total_metrics = torch.zeros(len(metric_fns))
+
+     model.eval()
+     with torch.no_grad():
+         for data, target in tqdm(data_loader):
+             data, target = data.to(device), target.to(device)
+             output = model(data)
+
+             # computing loss, metrics on test set, weighted by batch size
+             loss = loss_fn(output, target)
+             batch_size = data.shape[0]
+             total_loss += loss.item() * batch_size
+             for i, metric in enumerate(metric_fns):
+                 total_metrics[i] += metric(output, target) * batch_size
+
+     n_samples = len(data_loader.sampler)
+     log = {'loss': total_loss / n_samples}
+     log.update({
+         met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
+     })
+     logger.info(log)
+
+
+ if __name__ == '__main__':
+     args = argparse.ArgumentParser(description='PyTorch Template')
+     args.add_argument('-c', '--config', default=None, type=str,
+                       help='config file path (default: None)')
+     args.add_argument('-r', '--resume', default=None, type=str,
+                       help='path to latest checkpoint (default: None)')
+     args.add_argument('-d', '--device', default=None, type=str,
+                       help='indices of GPUs to enable (default: all)')
+
+     config = ConfigParser.from_args(args)
+     main(config)
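
test.py accumulates each batch's loss and metrics multiplied by the batch size and divides by the total sample count at the end, so a smaller final batch does not skew the averages. The arithmetic in isolation (helper name is illustrative, not from the repository):

```python
def weighted_average(batch_values, batch_sizes):
    """Average per-batch values, weighting each batch by its sample count.

    This is the same sum(value_i * n_i) / sum(n_i) accumulation that test.py
    performs with total_loss and n_samples; a plain batch mean would instead
    over-weight a short final batch.
    """
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)
```

For example, batch losses 1.0 (3 samples) and 2.0 (1 sample) average to 1.25 per sample, whereas a naive mean of the two batch losses would give 1.5.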
train.py ADDED
@@ -0,0 +1,85 @@
+ import argparse
+ import collections
+
+ import numpy as np
+ import torch
+
+ import data_loader.data_loaders as module_data
+ import model.loss as module_loss
+ import model.metric as module_metric
+ import model.model as module_arch
+ from parse_config import ConfigParser
+ from trainer import Trainer
+ from utils import prepare_device
+
+ # fix random seeds for reproducibility
+ SEED = 123
+ torch.manual_seed(SEED)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ np.random.seed(SEED)
+
+
+ def main(config):
+     logger = config.get_logger('train')
+
+     # setup data_loader instances
+     data_loader = config.init_obj('data_loader', module_data)
+     valid_data_loader = data_loader.split_validation()
+
+     # setup data_loader for inference
+     data_loader.inference = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         patch_size=config['data_loader']['args']['patch_size'],
+         batch_size=config['data_loader']['args']['batch_size'],
+         shuffle=False,
+         validation_split=0.0,
+         training=True,
+         num_workers=config['data_loader']['args']['num_workers']
+     )
+
+     # build model architecture, then print to console
+     model = config.init_obj('arch', module_arch)
+     logger.info(model)
+
+     # prepare for (multi-device) GPU training
+     device, device_ids = prepare_device(config['n_gpu'])
+     model = model.to(device)
+     if len(device_ids) > 1:
+         model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+     # get function handles of loss and metrics
+     criterion = getattr(module_loss, config['loss'])
+     metrics = [getattr(module_metric, met) for met in config['metrics']]
+
+     # build optimizer and learning rate scheduler; delete the lr_scheduler lines to disable scheduling
+     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+     optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
+     lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
+
+     trainer = Trainer(model, criterion, metrics, optimizer,
+                       config=config,
+                       device=device,
+                       data_loader=data_loader,
+                       valid_data_loader=valid_data_loader,
+                       lr_scheduler=lr_scheduler)
+
+     trainer.train()
+
+
+ if __name__ == '__main__':
+     args = argparse.ArgumentParser(description='PyTorch Template')
+     args.add_argument('-c', '--config', default=None, type=str,
+                       help='config file path (default: None)')
+     args.add_argument('-r', '--resume', default=None, type=str,
+                       help='path to latest checkpoint (default: None)')
+     args.add_argument('-d', '--device', default=None, type=str,
+                       help='indices of GPUs to enable (default: all)')
+
+     # custom cli options to modify configuration from default values given in json file.
+     CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+     options = [
+         CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
+         CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
+     ]
+     config = ConfigParser.from_args(args, options)
+     main(config)