|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import optim |
|
|
from torch.utils.data import DataLoader, ConcatDataset |
|
|
|
|
|
from models.network.crm import CRMNet |
|
|
from models.sobel_op import SobelComputer |
|
|
from dataset import OnlineTransformDataset_crm as OnlineTransformDataset |
|
|
from util.logger import BoardLogger |
|
|
from util.model_saver import ModelSaver |
|
|
from util.hyper_para import HyperParameters |
|
|
from util.log_integrator import Integrator |
|
|
from util.metrics_compute_crm import compute_loss_and_metrics, iou_hooks_to_be_used |
|
|
from util.image_saver_crm import vis_prediction |
|
|
|
|
|
import time |
|
|
import os |
|
|
import datetime |
|
|
|
|
|
# Let cuDNN benchmark convolution algorithms; worthwhile because the
# training input sizes are fixed, so the autotuned choice is reused.
torch.backends.cudnn.benchmark = True

# Parse command-line hyperparameters (lr, weight_decay, batch_size,
# iterations, steps, gamma, id, load, ... — see util.hyper_para).
para = HyperParameters()
para.parse()

# Build a unique run id "<id>_<timestamp>" used for log/checkpoint naming.
# Passing --id null disables it (long_id = None), which presumably makes
# BoardLogger/ModelSaver run in a no-persist mode — confirm in their code.
if para['id'].lower() != 'null':
    long_id = '%s_%s' % (para['id'],datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))
else:
    long_id = None
logger = BoardLogger(long_id)
# Record the full hyperparameter set up front so every run is reproducible.
logger.log_string('hyperpara', str(para))

print('CUDA Device count: ', torch.cuda.device_count())
|
|
|
|
|
|
|
|
# Refinement network with a ResNet-50 backbone.
model = CRMNet(backend='resnet50')
# NOTE(review): device_ids is hard-coded to exactly two GPUs — confirm the
# target machine matches (device count is printed above); consider deriving
# this from torch.cuda.device_count().
model = nn.DataParallel(
    model.cuda(), device_ids=[0,1]
)

# Optionally resume from a checkpoint. The model is already wrapped in
# DataParallel here, so the loaded state dict is expected to carry the
# 'module.' key prefix — presumably checkpoints come from ModelSaver below;
# verify the prefixes match.
if para['load'] is not None:
    model.load_state_dict(torch.load(para['load']))
optimizer = optim.Adam(model.parameters(), lr=para['lr'], weight_decay=para['weight_decay'])
|
|
|
|
|
|
|
|
# Training corpora. All five datasets share one online-augmentation pipeline;
# `method` selects a dataset-specific loading convention and `perturb`
# enables random perturbation (see OnlineTransformDataset for details).
duts_tr_dir = os.path.join('data', 'DUTS-TR')
duts_te_dir = os.path.join('data', 'DUTS-TE')
ecssd_dir = os.path.join('data', 'ecssd')
msra_dir = os.path.join('data', 'MSRA_10K')

fss_dataset = OnlineTransformDataset(os.path.join('data', 'fss'), method=0, perturb=True)
duts_tr_dataset = OnlineTransformDataset(duts_tr_dir, method=1, perturb=True)
duts_te_dataset = OnlineTransformDataset(duts_te_dir, method=1, perturb=True)
ecssd_dataset = OnlineTransformDataset(ecssd_dir, method=1, perturb=True)
msra_dataset = OnlineTransformDataset(msra_dir, method=1, perturb=True)

# Report the individual sizes before merging everything for training.
for _name, _ds in (('FSS', fss_dataset),
                   ('DUTS-TR', duts_tr_dataset),
                   ('DUTS-TE', duts_te_dataset),
                   ('ECSSD', ecssd_dataset),
                   ('MSRA-10K', msra_dataset)):
    print('%s dataset size: ' % _name, len(_ds))

train_dataset = ConcatDataset([fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset, msra_dataset])

print('Total training size: ', len(train_dataset))
|
|
|
|
|
|
|
|
def worker_init_fn(worker_id):
    """Re-seed NumPy inside each DataLoader worker process.

    Workers are forked with identical NumPy RNG state, so without this
    every worker would produce the same "random" augmentations.  The base
    seed is taken from the parent's current RNG state, which changes each
    epoch because np.random.seed() is called at the top of the epoch loop.

    Args:
        worker_id: integer worker index supplied by DataLoader.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32).  The state word is
    # a uint32 that can be as large as 2**32 - 1, so adding worker_id
    # without the modulo could overflow and raise ValueError.
    np.random.seed((base_seed + worker_id) % (2 ** 32))
|
|
|
|
|
|
|
|
# Eight worker processes, each re-seeded via worker_init_fn so their
# augmentations diverge.  drop_last keeps every batch at full batch_size;
# pin_memory speeds up host-to-GPU transfers.
train_loader = DataLoader(train_dataset, para['batch_size'], shuffle=True, num_workers=8,
                          worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True)

# Computes Sobel edge maps over the prediction/GT images each iteration —
# presumably feeding edge-aware terms in compute_loss_and_metrics; verify.
sobel_compute = SobelComputer()
|
|
|
|
|
|
|
|
# Decay LR by `gamma` at each milestone in `steps`.  scheduler.step() is
# called once per iteration in the loop below, so milestones are expressed
# in iterations, not epochs.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'], para['gamma'])

saver = ModelSaver(long_id)
report_interval = 50     # iterations between scalar-log flushes
save_im_interval = 800   # iterations between prediction visualizations
memory_chunk = 50176     # pixels queried per forward chunk (50176 = 224*224)

# Convert the requested iteration budget into whole epochs, rounded to
# the nearest epoch.
total_epoch = int(para['iterations']/len(train_loader) + 0.5)
print('Actual training epoch: ', total_epoch)

# Accumulates per-iteration losses/metrics and logs averages; the IoU hooks
# add derived IoU statistics at finalize time.
train_integrator = Integrator(logger)
train_integrator.add_hook(iou_hooks_to_be_used)
total_iter = 0
last_time = 0
|
|
for e in range(total_epoch):
    # Re-seed the parent RNG from OS entropy each epoch; worker_init_fn
    # derives per-worker seeds from this state, so augmentations change
    # from epoch to epoch.
    np.random.seed()
    epoch_start_time = time.time()

    model = model.train()
    for im, seg, gt, crm_data in train_loader:
        # Move the batch to GPU: image, input segmentation, ground truth,
        # and the coordinate/cell tensors used for the chunked queries.
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda()
        for k, v in crm_data.items():
            crm_data[k] = v.cuda()

        total_iter += 1
        # Periodic checkpoint, independent of the reporting intervals.
        if total_iter % 5000 == 0:
            saver.save_model(model, total_iter)

        # Query the model in pixel chunks of `memory_chunk` to bound GPU
        # memory, then stitch the chunk outputs back together.
        images = {}
        for i in range(0, seg.shape[-2]*seg.shape[-1], memory_chunk):
            chunk_images = model(im, seg, coord=crm_data['coord'][:, i:i+memory_chunk, :], cell=crm_data['cell'][:, i:i+memory_chunk, :])
            # First chunk seeds the dict; later chunks are concatenated
            # along the flattened-pixel dimension.
            if 'pred_224' not in images.keys():
                images = chunk_images
            else:
                for key in images.keys():
                    images[key] = torch.cat((images[key], chunk_images[key]), axis=1)
        # Fold the flattened per-pixel outputs back into spatial maps of
        # shape (B, C, H, W).  NOTE(review): assumes dim 1 of each output is
        # C * H * W after concatenation — confirm against CRMNet's output.
        for key in images.keys():
            images[key] = images[key].view(images[key].shape[0], images[key].shape[1]//(seg.shape[-2]*seg.shape[-1]), *seg.shape[-2:])

        # Attach inputs/targets so the loss, Sobel, and visualization code
        # can consume a single dict.
        images['im'] = im
        images['seg'] = seg
        images['gt'] = gt
        # Adds edge maps to `images` (used by edge-aware loss terms).
        sobel_compute.compute_edges(images)

        loss_and_metrics = compute_loss_and_metrics(images, para)
        train_integrator.add_dict(loss_and_metrics)

        optimizer.zero_grad()
        (loss_and_metrics['total_loss']).backward()
        optimizer.step()

        if total_iter % report_interval == 0:
            # NOTE(review): scheduler.get_lr() is deprecated in newer
            # PyTorch in favor of get_last_lr(), and inside a step it can
            # report a transiently-scaled value — consider switching.
            logger.log_scalar('train/lr', scheduler.get_lr()[0], total_iter)
            train_integrator.finalize('train', total_iter)
            train_integrator.reset_except_hooks()

        # Stepped per iteration (see scheduler setup): milestones in
        # para['steps'] are iteration counts.
        scheduler.step()

        if total_iter % save_im_interval == 0:
            predict_vis = vis_prediction(images)
            logger.log_cv2('train/predict', predict_vis, total_iter)

# Final checkpoint after the last epoch.
saver.save_model(model, total_iter)