# File size: 5,192 Bytes
# 352cafd
# (line-number index removed — extraction artifact, not part of the source)
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, ConcatDataset
from models.network.crm import CRMNet
from models.sobel_op import SobelComputer
from dataset import OnlineTransformDataset_crm as OnlineTransformDataset
from util.logger import BoardLogger
from util.model_saver import ModelSaver
from util.hyper_para import HyperParameters
from util.log_integrator import Integrator
from util.metrics_compute_crm import compute_loss_and_metrics, iou_hooks_to_be_used
from util.image_saver_crm import vis_prediction
import time
import os
import datetime
# cuDNN autotuner: picks the fastest conv algorithm for the (fixed) input
# sizes used here; a win because every batch is the same shape.
torch.backends.cudnn.benchmark = True

# Parse command line arguments
para = HyperParameters()
para.parse()

# Logging: a run id of 'null' (case-insensitive) disables the long id,
# which downstream (BoardLogger / ModelSaver) receive as None.
if para['id'].lower() != 'null':
    long_id = '%s_%s' % (para['id'],datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))
else:
    long_id = None
logger = BoardLogger(long_id)
logger.log_string('hyperpara', str(para))
print('CUDA Device count: ', torch.cuda.device_count())
# Construct model: CRM network on a ResNet-50 backend, wrapped in
# DataParallel across two GPUs.
# NOTE(review): device_ids=[0,1] is hard-coded — confirm this matches the
# deployment machines (CUDA device count is printed above).
model = CRMNet(backend='resnet50')
model = nn.DataParallel(
    model.cuda(), device_ids=[0,1]
)

# Optionally resume from a checkpoint path given on the command line.
# NOTE(review): state dict is loaded into the DataParallel wrapper, so the
# checkpoint presumably carries 'module.'-prefixed keys — verify against
# how ModelSaver writes checkpoints.
if para['load'] is not None:
    model.load_state_dict(torch.load(para['load']))

optimizer = optim.Adam(model.parameters(), lr=para['lr'], weight_decay=para['weight_decay'])
# Dataset roots under ./data. All five saliency/segmentation sets are
# concatenated into one training pool.
duts_tr_dir = os.path.join('data', 'DUTS-TR')
duts_te_dir = os.path.join('data', 'DUTS-TE')
ecssd_dir = os.path.join('data', 'ecssd')
msra_dir = os.path.join('data', 'MSRA_10K')

# method=0 for FSS vs method=1 for the others; perturb=True enables the
# dataset's online augmentation. NOTE(review): the semantics of `method`
# live in OnlineTransformDataset_crm — confirm 0/1 mapping there.
fss_dataset = OnlineTransformDataset(os.path.join('data', 'fss'), method=0, perturb=True)
duts_tr_dataset = OnlineTransformDataset(duts_tr_dir, method=1, perturb=True)
duts_te_dataset = OnlineTransformDataset(duts_te_dir, method=1, perturb=True)
ecssd_dataset = OnlineTransformDataset(ecssd_dir, method=1, perturb=True)
msra_dataset = OnlineTransformDataset(msra_dir, method=1, perturb=True)

print('FSS dataset size: ', len(fss_dataset))
print('DUTS-TR dataset size: ', len(duts_tr_dataset))
print('DUTS-TE dataset size: ', len(duts_te_dataset))
print('ECSSD dataset size: ', len(ecssd_dataset))
print('MSRA-10K dataset size: ', len(msra_dataset))

train_dataset = ConcatDataset([fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset, msra_dataset])
print('Total training size: ', len(train_dataset))
def worker_init_fn(worker_id):
    """Seed numpy distinctly in each DataLoader worker process.

    Without this, forked workers inherit an identical numpy RNG state and
    produce identical augmentations. The seed is derived from the parent
    process's current RNG state plus the worker index, so reseeding the
    parent (done once per epoch) refreshes every worker.
    See https://github.com/pytorch/pytorch/issues/5059.
    """
    parent_seed = np.random.get_state()[1][0]
    np.random.seed(parent_seed + worker_id)
# Dataloaders, multi-process data loading
train_loader = DataLoader(train_dataset, para['batch_size'], shuffle=True, num_workers=8,
                          worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True)

sobel_compute = SobelComputer()

# Learning rate decay scheduling: para['steps'] are milestone iterations
# (the scheduler is stepped once per batch in the loop below), decayed by
# para['gamma'] at each milestone.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'], para['gamma'])

saver = ModelSaver(long_id)
report_interval = 50      # iterations between scalar-log flushes
save_im_interval = 800    # iterations between visual prediction dumps
memory_chunk = 50176      # pixels per forward chunk; 50176 == 224*224, one full crop

# Convert the requested iteration budget into whole epochs (rounded).
total_epoch = int(para['iterations']/len(train_loader) + 0.5)
print('Actual training epoch: ', total_epoch)

train_integrator = Integrator(logger)
train_integrator.add_hook(iou_hooks_to_be_used)
total_iter = 0
last_time = 0  # NOTE(review): never read afterwards — candidate for removal
# ---- Training loop -----------------------------------------------------
# NOTE(review): the pasted source had its leading whitespace stripped; the
# indentation below is reconstructed from the control-flow keywords.
for e in range(total_epoch):
    np.random.seed() # reset seed so this epoch's workers get fresh seeds via worker_init_fn
    epoch_start_time = time.time()  # NOTE(review): captured but never reported

    # Train loop
    model = model.train()
    for im, seg, gt, crm_data in train_loader:
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda() # [12, 3, 224, 224] [12, 1, 224, 224] [12, 1, 224, 224]
        # Move the auxiliary CRM tensors (at least 'coord' and 'cell' are
        # used below) to the GPU. Rebinding values during .items()
        # iteration is safe in Py3 since no keys are added or removed.
        for k, v in crm_data.items():
            crm_data[k] = v.cuda()

        total_iter += 1
        # Periodic checkpoint, independent of the logging intervals.
        if total_iter % 5000 == 0:
            saver.save_model(model, total_iter)

        # Forward pass over the image in pixel chunks of `memory_chunk`
        # (bounds peak memory); each chunk's output dict is concatenated
        # along dim 1, the flattened-pixel axis.
        images = {}
        for i in range(0, seg.shape[-2]*seg.shape[-1], memory_chunk):
            chunk_images = model(im, seg, coord=crm_data['coord'][:, i:i+memory_chunk, :], cell=crm_data['cell'][:, i:i+memory_chunk, :])
            if 'pred_224' not in images.keys():
                # First chunk: adopt the model's output dict wholesale.
                images = chunk_images
            else:
                for key in images.keys():
                    images[key] = torch.cat((images[key], chunk_images[key]), axis=1)
        # Fold the flattened pixel axis back into (H, W) spatial maps;
        # dim 1 becomes total//(H*W) channels followed by H, W.
        for key in images.keys():
            images[key] = images[key].view(images[key].shape[0], images[key].shape[1]//(seg.shape[-2]*seg.shape[-1]), *seg.shape[-2:])

        # Stash the raw inputs alongside predictions for loss/visualization.
        images['im'] = im
        images['seg'] = seg
        images['gt'] = gt

        sobel_compute.compute_edges(images)
        loss_and_metrics = compute_loss_and_metrics(images, para)
        train_integrator.add_dict(loss_and_metrics)

        optimizer.zero_grad()
        (loss_and_metrics['total_loss']).backward()
        optimizer.step()

        if total_iter % report_interval == 0:
            # NOTE(review): get_lr() is deprecated in newer PyTorch in
            # favor of get_last_lr(); kept as-is because the workaround
            # below depends on calling it before scheduler.step().
            logger.log_scalar('train/lr', scheduler.get_lr()[0], total_iter)
            train_integrator.finalize('train', total_iter)
            train_integrator.reset_except_hooks()

        # Need to put step AFTER get_lr() for correct logging, see issue #22107 in PyTorch
        scheduler.step()

        if total_iter % save_im_interval == 0:
            predict_vis = vis_prediction(images)
            logger.log_cv2('train/predict', predict_vis, total_iter)

# Final save!
saver.save_model(model, total_iter)
# (end-of-file extraction artifact removed)