|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
TruFor training script: trains the model and validates it each epoch,
keeping the best checkpoint according to the configured validation key.

Created in September 2022

@author: fabrizio.guillaro
"""
|
|
|
|
|
import sys, os |
|
|
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') |
|
|
if path not in sys.path: |
|
|
sys.path.insert(0, path) |
|
|
|
|
|
import argparse |
|
|
|
|
|
import logging |
|
|
import time |
|
|
import timeit |
|
|
|
|
|
import gc |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
import torch.optim |
|
|
# NOTE(review): autograd anomaly detection adds per-op checks and slows training
# noticeably; presumably left enabled for debugging — confirm it is intended for full runs.
torch.autograd.set_detect_anomaly(True)
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
from lib.config import config, update_config |
|
|
from lib.core.function import train, validate |
|
|
from lib.utils import get_model, get_optimizer |
|
|
from lib.utils import create_logger, FullModel, adjust_learning_rate |
|
|
|
|
|
from dataset.data_core import myDataset |
|
|
import albumentations |
|
|
|
|
|
|
|
|
def main():
    """Train TruFor.

    Parses CLI arguments, builds the datasets, model and optimizer from the
    experiment config, optionally loads pretraining/resume checkpoints, then
    runs the train/validate loop, saving ``checkpoint.pth.tar`` every trained
    epoch and ``best.pth.tar`` whenever the tracked validation metric improves.
    """
    parser = argparse.ArgumentParser(description='Train TruFor')
    parser.add_argument('-exp', '--experiment', type=str)
    parser.add_argument('-g', '--gpu', type=int, default=[0], nargs="+", help='device(s)')
    parser.add_argument('opts', help='other options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Restrict CUDA to the requested devices, then renumber them locally 0..N-1.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu)
    args.gpu = range(len(args.gpu))

    update_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(config, f'{args.experiment}', 'train')
    logger.info(config)
    logger.info('\n')

    # cuDNN behaviour (speed vs. reproducibility) comes from the experiment config.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    gpus = list(config.GPUS)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Optional albumentations pipelines, loaded from YAML files when configured.
    aug_train = albumentations.load(config.TRAIN.AUG, data_format='yaml') if config.TRAIN.AUG is not None else None
    aug_valid = albumentations.load(config.VALID.AUG, data_format='yaml') if config.VALID.AUG is not None else None

    logger.info(f'Train augmentation: {config.TRAIN.AUG} {aug_train}')
    logger.info(f'Validation augmentation: {config.VALID.AUG} {aug_valid}')

    # IMAGE_SIZE is stored (width, height); the dataset expects (height, width).
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = myDataset(config, crop_size=crop_size, grid_crop=False, mode='train', aug=aug_train)
    valid_dataset = myDataset(config, crop_size=None, grid_crop=False, mode="valid", aug=aug_valid,
                              max_dim=config.VALID.MAX_SIZE)

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS)

    # Validation images are uncropped (variable size), so batch them one at a time.
    validloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.WORKERS)

    model = get_model(config)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
    model = FullModel(model, config)

    optimizer = get_optimizer(model, config)

    # was: train_dataset.__len__() — len() is the idiomatic form; np.int32 truncates as before.
    epoch_iters = np.int32(len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    # Metric used to select the best model: losses are minimized, anything else maximized.
    best_key = config.VALID.BEST_KEY
    best_value = np.inf if 'loss' in best_key else 0
    logger.info(f'best valid key: {best_key}')

    def _save_state(path, epoch, best_value):
        # Serialize full training state (weights, optimizer, bookkeeping) to *path*.
        torch.save({
            'epoch': epoch + 1,
            'best_value': best_value,
            'best_key': best_key,
            'state_dict': model.model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, path)

    last_epoch = 0
    # was: `not X == ''` and `not X == None` — empty/None both mean "no pretraining".
    if config.TRAIN.PRETRAINING:
        model_state_file = config.TRAIN.PRETRAINING
        # was: assert (stripped under -O) — raise explicitly instead.
        if not os.path.isfile(model_state_file):
            raise FileNotFoundError(model_state_file)
        checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
        state_dict = checkpoint['state_dict']
        try:
            model.model.module.load_state_dict(state_dict, strict=False)
        except RuntimeError:
            # was: bare except. load_state_dict raises RuntimeError on shape mismatch;
            # drop the detection-head weights and retry with the backbone only.
            state_dict = {k: state_dict[k] for k in state_dict if not k.startswith('detection')}
            model.model.module.load_state_dict(state_dict, strict=False)
        del checkpoint
        del state_dict
        logger.info("=> loaded pretraining ({})".format(model_state_file))

    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
            best_value = checkpoint['best_value']
            # was: assert — a checkpoint tracking a different metric cannot be resumed.
            if checkpoint['best_key'] != best_key:
                raise ValueError(
                    f"checkpoint best_key {checkpoint['best_key']!r} != configured {best_key!r}")
            last_epoch = checkpoint['epoch']
            model.model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            writer_dict['train_global_steps'] = last_epoch
        else:
            logger.info("No previous checkpoint.")

    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    start_epoch = last_epoch
    if config.VALID.FIRST_VALID:
        # Start one epoch early: the first iteration fails `epoch >= last_epoch`,
        # so it skips training and runs validation only.
        start_epoch = start_epoch - 1

    for epoch in range(start_epoch, end_epoch):

        if epoch >= last_epoch:
            train_dataset.shuffle()

            print(f'TRAINING epoch {epoch}:')
            train(epoch, config.TRAIN.END_EPOCH,
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict,
                  adjust_learning_rate=adjust_learning_rate)

            # Free cached GPU memory before validation; the short sleep presumably
            # lets worker processes settle — kept from the original.
            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(1.0)

            ckpt_path = os.path.join(final_output_dir, 'checkpoint.pth.tar')
            logger.info('=> saving checkpoint to {}'.format(ckpt_path))
            _save_state(ckpt_path, epoch, best_value)

        print(f'VALIDATION epoch {epoch}:')
        writer_dict['valid_global_steps'] = epoch

        value_valid, IoU_array, confusion_matrix = \
            validate(config, validloader, model, writer_dict, "valid")

        torch.cuda.empty_cache()
        gc.collect()
        time.sleep(3.0)

        # Deduplicated best-model bookkeeping (the save dict was copied verbatim
        # in both branches): minimize loss-type keys, maximize everything else.
        if 'loss' in best_key:
            improved = value_valid[best_key] < best_value
        else:
            improved = value_valid[best_key] > best_value
        if improved:
            best_value = value_valid[best_key]
            _save_state(os.path.join(final_output_dir, 'best.pth.tar'), epoch, best_value)
            logger.info("best.pth.tar updated.")

        msg = '(Valid) Loss: {:.3f}, Best_{:s}: {: 4.4f}'.format(
            value_valid['loss'], best_key, best_value)
        # was: logging.info — use the experiment logger so these lines reach the run's log file.
        logger.info(msg)
        logger.info(IoU_array)
        logger.info("confusion_matrix:")
        logger.info(confusion_matrix)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|
|