|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Created in September 2022
|
| @author: fabrizio.guillaro
|
| """
|
|
|
| import sys, os
|
# Put the parent of this script's directory at the front of the import search
# path (only once), so that the `lib` and `dataset` imports below can resolve.
_script_dir = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(_script_dir, '..')
if path not in sys.path:
    sys.path.insert(0, path)
|
|
|
| import argparse
|
|
|
| import logging
|
| import time
|
| import timeit
|
|
|
| import gc
|
| import numpy as np
|
|
|
| import torch
|
| import torch.backends.cudnn as cudnn
|
| import torch.optim
|
| torch.autograd.set_detect_anomaly(True)
|
| from tensorboardX import SummaryWriter
|
|
|
| from lib.config import config, update_config
|
| from lib.core.function import train, validate
|
| from lib.utils import get_model, get_optimizer
|
| from lib.utils import create_logger, FullModel, adjust_learning_rate
|
|
|
| from dataset.data_core import myDataset
|
| import albumentations
|
|
|
|
|
def _is_improvement(best_key, candidate, best_value):
    """Return True when `candidate` beats `best_value` for the metric `best_key`.

    Metrics whose name contains 'loss' are minimized; every other metric is
    maximized. Ties are not counted as an improvement.
    """
    if 'loss' in best_key:
        return bool(candidate < best_value)
    return bool(candidate > best_value)


def _checkpoint_state(epoch, best_value, best_key, model, optimizer):
    """Build the dictionary that is serialized into a checkpoint file.

    `epoch` is the epoch that has just finished; the stored value is
    `epoch + 1`, i.e. the epoch a resumed run starts from.
    """
    return {
        'epoch': epoch + 1,
        'best_value': best_value,
        'best_key': best_key,
        'state_dict': model.model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }


def main():
    """Command-line entry point: train the model and validate after every epoch.

    Steps, in order:
      1. parse the command line and merge it into the global `config`;
      2. build the logger, the TensorBoard writer, the optional augmentations,
         the datasets and their loaders;
      3. build the model (wrapped in DataParallel and FullModel) and optimizer;
      4. optionally load pretraining weights and/or resume from
         `checkpoint.pth.tar` in the experiment output directory;
      5. run the epoch loop, saving `checkpoint.pth.tar` after each training
         pass and `best.pth.tar` whenever `config.VALID.BEST_KEY` improves.

    Raises:
        FileNotFoundError: `config.TRAIN.PRETRAINING` is set but is not a file.
        ValueError: the resumed checkpoint was tracked with a different
            best key than the current configuration.
    """
    parser = argparse.ArgumentParser(description='Train TruFor')
    parser.add_argument('-exp', '--experiment', type=str)
    parser.add_argument('-g', '--gpu', type=int, default=[0], nargs="+", help='device(s)')
    parser.add_argument('opts', help='other options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Expose only the requested devices, then re-index them from 0 so the
    # indices stored in args.gpu refer to the visible devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu)
    args.gpu = range(len(args.gpu))

    update_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(config, f'{args.experiment}', 'train')
    logger.info(config)
    logger.info('\n')

    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    gpus = list(config.GPUS)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Augmentation pipelines are optional and described by YAML files.
    aug_train = None
    if config.TRAIN.AUG is not None:
        aug_train = albumentations.load(config.TRAIN.AUG, data_format='yaml')

    aug_valid = None
    if config.VALID.AUG is not None:
        aug_valid = albumentations.load(config.VALID.AUG, data_format='yaml')

    logger.info(f'Train augmentation: {config.TRAIN.AUG} {aug_train}')
    logger.info(f'Validation augmentation: {config.VALID.AUG} {aug_valid}')

    # IMAGE_SIZE entries are swapped to obtain the crop size.
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = myDataset(config, crop_size=crop_size, grid_crop=False, mode='train', aug=aug_train)
    valid_dataset = myDataset(config, crop_size=None, grid_crop=False, mode="valid", aug=aug_valid,
                              max_dim=config.VALID.MAX_SIZE)

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS)

    # Validation images are not cropped, so they are processed one at a time.
    validloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.WORKERS)

    model = get_model(config)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
    model = FullModel(model, config)

    optimizer = get_optimizer(model, config)

    # Number of full batches per epoch (truncated).
    epoch_iters = np.int32(len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    # Loss-like keys start from +inf (lower is better), other keys from 0.
    best_key = config.VALID.BEST_KEY
    best_value = np.inf if 'loss' in best_key else 0
    logger.info(f'best valid key: {best_key}')

    checkpoint_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')
    best_path = os.path.join(final_output_dir, 'best.pth.tar')

    last_epoch = 0

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    if config.TRAIN.PRETRAINING not in ('', None):
        model_state_file = config.TRAIN.PRETRAINING
        if not os.path.isfile(model_state_file):
            raise FileNotFoundError(
                'pretraining checkpoint not found: {}'.format(model_state_file))
        checkpoint = torch.load(model_state_file, map_location='cpu')
        state_dict = checkpoint['state_dict']
        try:
            model.model.module.load_state_dict(state_dict, strict=False)
        except RuntimeError as err:
            # load_state_dict reports incompatible tensors with RuntimeError even
            # when strict=False; retry without the 'detection*' entries.
            logger.warning('pretraining load failed (%s); retrying without "detection" weights', err)
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('detection')}
            model.model.module.load_state_dict(state_dict, strict=False)
        del checkpoint
        del state_dict
        logger.info("=> loaded pretraining ({})".format(model_state_file))

    if config.TRAIN.RESUME:
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            if checkpoint['best_key'] != best_key:
                raise ValueError(
                    'checkpoint best_key {!r} does not match configured best_key {!r}'.format(
                        checkpoint['best_key'], best_key))
            best_value = checkpoint['best_value']
            last_epoch = checkpoint['epoch']
            model.model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            writer_dict['train_global_steps'] = last_epoch
        else:
            logger.info("No previous checkpoint.")

    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    start_epoch = last_epoch
    if config.VALID.FIRST_VALID:
        # Start one epoch early: that extra pass skips training (see the
        # `epoch >= last_epoch` test below) and only runs validation.
        start_epoch = start_epoch - 1

    try:
        for epoch in range(start_epoch, end_epoch):

            if epoch >= last_epoch:
                train_dataset.shuffle()

                print(f'TRAINING epoch {epoch}:')
                train(epoch, config.TRAIN.END_EPOCH,
                      epoch_iters, config.TRAIN.LR, num_iters,
                      trainloader, optimizer, model, writer_dict,
                      adjust_learning_rate=adjust_learning_rate)

                torch.cuda.empty_cache()
                gc.collect()
                time.sleep(1.0)

                # Rolling checkpoint, written after every training pass.
                logger.info('=> saving checkpoint to {}'.format(checkpoint_path))
                torch.save(
                    _checkpoint_state(epoch, best_value, best_key, model, optimizer),
                    checkpoint_path)

            print(f'VALIDATION epoch {epoch}:')
            writer_dict['valid_global_steps'] = epoch

            value_valid, IoU_array, confusion_matrix = \
                validate(config, validloader, model, writer_dict, "valid")

            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(3.0)

            if _is_improvement(best_key, value_valid[best_key], best_value):
                best_value = value_valid[best_key]
                torch.save(
                    _checkpoint_state(epoch, best_value, best_key, model, optimizer),
                    best_path)
                logger.info("best.pth.tar updated.")

            msg = '(Valid) Loss: {:.3f}, Best_{:s}: {: 4.4f}'.format(
                value_valid['loss'], best_key, best_value)
            logger.info(msg)
            logger.info(IoU_array)
            logger.info("confusion_matrix:")
            logger.info(confusion_matrix)
    finally:
        # Flush and release the TensorBoard event file even if training aborts.
        writer_dict['writer'].close()
|
|
|
|
|
|
|
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|
|