# File size: 8,305 Bytes
# 5f0437a
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import sys, os
# Prepend this script's parent directory to the module search path (once).
# Presumably this is what lets the project-local `lib` and `dataset`
# packages imported further down resolve -- confirm against the repo layout.
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)
if path not in sys.path:
    sys.path.insert(0, path)
import argparse
import logging
import time
import timeit
import gc
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim
torch.autograd.set_detect_anomaly(True)
from tensorboardX import SummaryWriter
from lib.config import config, update_config
from lib.core.function import train, validate
from lib.utils import get_model, get_optimizer
from lib.utils import create_logger, FullModel, adjust_learning_rate
from dataset.data_core import myDataset
import albumentations
def _save_checkpoint(file_path, next_epoch, best_value, best_key, model, optimizer):
    """Serialize a training snapshot to ``file_path`` with ``torch.save``.

    Args:
        file_path: destination file.
        next_epoch: value stored under ``'epoch'`` (the epoch to resume from).
        best_value: best validation score observed so far.
        best_key: name of the validation entry ``best_value`` refers to.
        model: wrapper whose ``model.module`` attribute is the network whose
            weights are stored (as built in ``main``).
        optimizer: optimizer whose state is stored alongside the weights.
    """
    torch.save({
        'epoch': next_epoch,
        'best_value': best_value,
        'best_key': best_key,
        'state_dict': model.model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, file_path)


def main():
    """Command-line entry point: train TruFor and validate after every epoch.

    Parses the command line, updates the global ``config``, builds the
    datasets, loaders, model and optimizer, optionally loads pretraining
    weights and/or resumes from ``checkpoint.pth.tar``, then alternates
    training and validation. ``checkpoint.pth.tar`` is rewritten after each
    training epoch and ``best.pth.tar`` whenever the monitored validation
    entry (``config.VALID.BEST_KEY``) improves.

    Raises:
        FileNotFoundError: ``config.TRAIN.PRETRAINING`` is set but is not a file.
        ValueError: the resumed checkpoint was produced with a different
            ``best_key`` than the one currently configured.
    """
    parser = argparse.ArgumentParser(description='Train TruFor')
    parser.add_argument('-exp', '--experiment', type=str)
    parser.add_argument('-g', '--gpu', type=int, default=[0], nargs="+", help='device(s)')
    parser.add_argument('opts', help='other options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Restrict CUDA to the requested devices; once masked, they are visible
    # to this process as 0..N-1, hence the re-indexing of args.gpu.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu)
    args.gpu = range(len(args.gpu))
    update_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(config, f'{args.experiment}', 'train')
    logger.info(config)
    logger.info('\n')

    # cudnn setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Optional augmentation pipelines, deserialized from YAML files.
    aug_train = None
    if config.TRAIN.AUG is not None:
        aug_train = albumentations.load(config.TRAIN.AUG, data_format='yaml')
    aug_valid = None
    if config.VALID.AUG is not None:
        aug_valid = albumentations.load(config.VALID.AUG, data_format='yaml')
    logger.info(f'Train augmentation: {config.TRAIN.AUG} {aug_train}')
    logger.info(f'Validation augmentation: {config.VALID.AUG} {aug_valid}')

    # The two IMAGE_SIZE entries are swapped to build the crop size.
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = myDataset(config, crop_size=crop_size, grid_crop=False, mode='train', aug=aug_train)
    valid_dataset = myDataset(config, crop_size=None, grid_crop=False, mode="valid", aug=aug_valid,
                              max_dim=config.VALID.MAX_SIZE)

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS)
    validloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,     # 1 to allow arbitrary input sizes
        shuffle=False,    # must be False to get accurate filename
        num_workers=config.WORKERS)

    # model
    model = get_model(config)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
    model = FullModel(model, config)

    # optimizer
    optimizer = get_optimizer(model, config)

    epoch_iters = np.int32(len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    # Keys containing 'loss' are minimized, every other key is maximized.
    best_key = config.VALID.BEST_KEY
    lower_is_better = 'loss' in best_key
    best_value = np.inf if lower_is_better else 0
    logger.info(f'best valid key: {best_key}')

    checkpoint_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')
    best_path = os.path.join(final_output_dir, 'best.pth.tar')
    last_epoch = 0

    pretraining = config.TRAIN.PRETRAINING
    if pretraining != '' and pretraining is not None:
        model_state_file = pretraining
        if not os.path.isfile(model_state_file):
            raise FileNotFoundError(f'pretraining checkpoint not found: {model_state_file}')
        # map_location keeps every tensor on the CPU while deserializing.
        checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        try:
            model.model.module.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            # strict=False tolerates missing/unexpected keys but still raises
            # on tensors whose shapes differ; retry without the 'detection*'
            # entries.
            logger.info("=> retrying pretraining load without 'detection*' weights")
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('detection')}
            model.model.module.load_state_dict(state_dict, strict=False)
        del checkpoint
        del state_dict
        logger.info("=> loaded pretraining ({})".format(model_state_file))

    if config.TRAIN.RESUME:
        model_state_file = checkpoint_path
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
            best_value = checkpoint['best_value']
            if checkpoint['best_key'] != best_key:
                raise ValueError(
                    f"checkpoint best_key {checkpoint['best_key']!r} does not match "
                    f"configured best_key {best_key!r}")
            last_epoch = checkpoint['epoch']
            model.model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            writer_dict['train_global_steps'] = last_epoch
        else:
            logger.info("No previous checkpoint.")

    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters

    # With FIRST_VALID the loop starts one epoch early; that extra pass fails
    # the `epoch >= last_epoch` test below and therefore only validates.
    start_epoch = last_epoch - 1 if config.VALID.FIRST_VALID else last_epoch

    for epoch in range(start_epoch, end_epoch):
        # train
        # NOTE(review): the nesting of this section was reconstructed from a
        # flattened source; training, cleanup and the rolling checkpoint are
        # assumed to be skipped together on the validation-only pass -- confirm.
        if epoch >= last_epoch:
            train_dataset.shuffle()  # for class-balanced sampling

            print(f'TRAINING epoch {epoch}:')
            train(epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict,
                  adjust_learning_rate=adjust_learning_rate)

            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(1.0)

            logger.info('=> saving checkpoint to {}'.format(checkpoint_path))
            _save_checkpoint(checkpoint_path, epoch + 1, best_value, best_key, model, optimizer)

        # valid
        print(f'VALIDATION epoch {epoch}:')
        writer_dict['valid_global_steps'] = epoch
        value_valid, IoU_array, confusion_matrix = \
            validate(config, validloader, model, writer_dict, "valid")

        torch.cuda.empty_cache()
        gc.collect()
        time.sleep(3.0)

        current = value_valid[best_key]
        if lower_is_better:
            improved = current < best_value   # smallest loss
        else:
            improved = current > best_value   # highest metric
        if improved:
            best_value = current
            _save_checkpoint(best_path, epoch + 1, best_value, best_key, model, optimizer)
            logger.info("best.pth.tar updated.")

        msg = '(Valid) Loss: {:.3f}, Best_{:s}: {: 4.4f}'.format(
            value_valid['loss'], best_key, best_value)
        logger.info(msg)
        logger.info(IoU_array)
        logger.info("confusion_matrix:")
        logger.info(confusion_matrix)

    # Flush and release the TensorBoard event file.
    writer_dict['writer'].close()


if __name__ == '__main__':
    main()
|