# Provenance (HuggingFace upload artifact): uploaded via huggingface_hub, commit 5f0437a (verified).
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
import sys, os

# Make the repository root (parent of this script's directory) importable,
# so that `lib.*` and `dataset.*` below resolve when run as a script.
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
if path not in sys.path:
    sys.path.insert(0, path)

import argparse
import logging
import time
import timeit
import gc

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim

# NOTE(review): anomaly detection adds substantial overhead to every backward
# pass; it is usually enabled only while debugging NaN gradients — confirm it
# is intended for production training runs.
torch.autograd.set_detect_anomaly(True)

from tensorboardX import SummaryWriter

# Project-local modules (resolved via the sys.path insertion above).
from lib.config import config, update_config
from lib.core.function import train, validate
from lib.utils import get_model, get_optimizer
from lib.utils import create_logger, FullModel, adjust_learning_rate
from dataset.data_core import myDataset

import albumentations
def _save_checkpoint(path, epoch, best_value, best_key, model, optimizer):
    """Serialize the full training state to `path`.

    Stored keys: 'epoch' (next epoch to run), 'best_value'/'best_key'
    (best validation metric so far), 'state_dict' (raw network weights,
    unwrapped from FullModel/DataParallel), 'optimizer'.
    """
    torch.save({
        'epoch': epoch + 1,
        'best_value': best_value,
        'best_key': best_key,
        'state_dict': model.model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def main():
    """Train TruFor: build data/model/optimizer from config, then run the
    epoch loop with per-epoch validation and best-checkpoint tracking.

    Command line:
        -exp/--experiment  experiment name used by create_logger
        -g/--gpu           physical GPU id(s); remapped to logical 0..N-1
                           through CUDA_VISIBLE_DEVICES
        opts               remaining key/value overrides for update_config
    """
    parser = argparse.ArgumentParser(description='Train TruFor')
    parser.add_argument('-exp', '--experiment', type=str)
    parser.add_argument('-g', '--gpu', type=int, default=[0], nargs="+", help='device(s)')
    parser.add_argument('opts', help='other options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Restrict visible devices, then address them by logical index 0..N-1.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu)
    args.gpu = range(len(args.gpu))
    update_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(config, f'{args.experiment}', 'train')
    logger.info(config)
    logger.info('\n')

    # cudnn setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    gpus = list(config.GPUS)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Optional albumentations pipelines, deserialized from YAML files.
    aug_train = (albumentations.load(config.TRAIN.AUG, data_format='yaml')
                 if config.TRAIN.AUG is not None else None)
    aug_valid = (albumentations.load(config.VALID.AUG, data_format='yaml')
                 if config.VALID.AUG is not None else None)
    logger.info(f'Train augmentation: {config.TRAIN.AUG} {aug_train}')
    logger.info(f'Validation augmentation: {config.VALID.AUG} {aug_valid}')

    # IMAGE_SIZE is (width, height); myDataset's crop_size wants (height, width).
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = myDataset(config, crop_size=crop_size, grid_crop=False, mode='train', aug=aug_train)
    valid_dataset = myDataset(config, crop_size=None, grid_crop=False, mode="valid", aug=aug_valid,
                              max_dim=config.VALID.MAX_SIZE)

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS)

    validloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,    # 1 to allow arbitrary input sizes
        shuffle=False,   # must be False to get accurate filename
        num_workers=config.WORKERS)

    # model
    model = get_model(config)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
    model = FullModel(model, config)

    # optimizer
    optimizer = get_optimizer(model, config)

    # Fix: idiomatic len() instead of __len__(); value truncates as before.
    epoch_iters = np.int32(len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    best_key = config.VALID.BEST_KEY
    # Loss-type keys improve downward; every other metric improves upward.
    minimize = 'loss' in best_key
    best_value = np.inf if minimize else 0
    logger.info(f'best valid key: {best_key}')

    last_epoch = 0
    # Fix: explicit empty/None membership test instead of `not x == None`.
    if config.TRAIN.PRETRAINING not in ('', None):
        model_state_file = config.TRAIN.PRETRAINING
        assert os.path.isfile(model_state_file)
        checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        try:
            model.model.module.load_state_dict(state_dict, strict=False)
        # Fix: narrow the bare except — load_state_dict raises RuntimeError on
        # tensor shape mismatches; retry without the incompatible detection head.
        except RuntimeError:
            state_dict = {k: state_dict[k] for k in state_dict if not k.startswith('detection')}
            model.model.module.load_state_dict(state_dict, strict=False)
        del checkpoint
        del state_dict
        logger.info("=> loaded pretraining ({})".format(model_state_file))

    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
            best_value = checkpoint['best_value']
            assert checkpoint['best_key'] == best_key
            last_epoch = checkpoint['epoch']
            model.model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            writer_dict['train_global_steps'] = last_epoch
        else:
            logger.info("No previous checkpoint.")

    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    start_epoch = last_epoch
    if config.VALID.FIRST_VALID:
        # Start one epoch early so the first loop iteration only validates.
        start_epoch = start_epoch - 1

    for epoch in range(start_epoch, end_epoch):
        # train (skipped on the FIRST_VALID validation-only iteration)
        if epoch >= last_epoch:
            train_dataset.shuffle()  # for class-balanced sampling
            print(f'TRAINING epoch {epoch}:')
            train(epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict,
                  adjust_learning_rate=adjust_learning_rate)
            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(1.0)

            ckpt_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')
            logger.info('=> saving checkpoint to {}'.format(ckpt_path))
            _save_checkpoint(ckpt_path, epoch, best_value, best_key, model, optimizer)

        # valid
        print(f'VALIDATION epoch {epoch}:')
        writer_dict['valid_global_steps'] = epoch
        value_valid, IoU_array, confusion_matrix = \
            validate(config, validloader, model, writer_dict, "valid")
        torch.cuda.empty_cache()
        gc.collect()
        time.sleep(3.0)

        # Unified best-model check (was two duplicated save branches):
        # smaller is better for loss keys, larger is better otherwise.
        if minimize:
            improved = value_valid[best_key] < best_value
        else:
            improved = value_valid[best_key] > best_value
        if improved:
            best_value = value_valid[best_key]
            _save_checkpoint(os.path.join(final_output_dir, 'best.pth.tar'),
                             epoch, best_value, best_key, model, optimizer)
            logger.info("best.pth.tar updated.")

        msg = '(Valid) Loss: {:.3f}, Best_{:s}: {: 4.4f}'.format(
            value_valid['loss'], best_key, best_value)
        # Fix: report through the experiment logger, not the root `logging` module.
        logger.info(msg)
        logger.info(IoU_array)
        logger.info("confusion_matrix:")
        logger.info(confusion_matrix)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()