|
|
import os |
|
|
import time |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
import torch |
|
|
from torch.autograd import Variable |
|
|
|
|
|
import torchtask |
|
|
from torchtask.utils import logger, cmd, tool |
|
|
from torchtask.nn import func |
|
|
|
|
|
|
|
|
def add_parser_arguments(parser):
    """Register this trainer's command-line options on ``parser``.

    No trainer-specific options are declared here; the call is simply
    forwarded to ``torchtask.trainer_template.add_parser_arguments``.
    """
    register_template_arguments = torchtask.trainer_template.add_parser_arguments
    register_template_arguments(parser)
|
|
|
|
|
|
|
|
def harmonizer_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    """Create a ``HarmonizerTrainer`` and build it from the ``'model'`` entries.

    Each of the four dicts must contain a ``'model'`` key; its value is wrapped
    in a one-element list and handed to ``HarmonizerTrainer.build`` together
    with ``task_func``.

    Returns:
        The built ``HarmonizerTrainer`` instance.
    """
    # Resolve every entry before the trainer is constructed, so a missing key
    # fails before any trainer state exists.
    model_entry, optimizer_entry, lrer_entry, criterion_entry = (
        table['model']
        for table in (model_dict, optimizer_dict, lrer_dict, criterion_dict)
    )

    trainer = HarmonizerTrainer(args)
    trainer.build(
        [model_entry],
        [optimizer_entry],
        [lrer_entry],
        [criterion_entry],
        task_func,
    )
    return trainer
|
|
|
|
|
|
|
|
class HarmonizerTrainer(torchtask.trainer_template.TaskTrainer):
    """Trainer that optimises one model with one optimizer, lrer and criterion.

    A training batch is split at ``args.labeled_batch_size``: the leading
    samples (prefix ``l_``) get synthetic inputs and per-filter targets built
    by ``model.module.adjust``; the remaining samples (prefix ``u_``), used
    unless ``args.ignore_additional`` is set, only carry a target for the last
    filter output. Targets that do not exist are filled with ``_MISSING_GT``.
    """

    # Sentinel written into ground-truth slots that have no real target.
    _MISSING_GT = -999

    def __init__(self, args):
        """Initialise the base trainer and declare the components set by ``_build``."""
        super(HarmonizerTrainer, self).__init__(args)

        self.model = None
        self.optimizer = None
        self.lrer = None
        self.criterion = None

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Instantiate model, optimizer, lrer and criterion.

        Only the first element of each ``*_funcs`` list is used. Every created
        object is also registered in the matching ``self.models`` /
        ``self.optimizers`` / ``self.lrers`` / ``self.criterions`` dict.
        """
        self.task_func = task_func

        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        self.models = {'model': self.model}

        # The optimizer is built from the wrapped module's ``param_groups``.
        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.optimizers = {'optimizer': self.optimizer}

        self.lrer = lrer_funcs[0](self.optimizer)
        self.lrers = {'lrer': self.lrer}

        self.criterion = criterion_funcs[0](self.args)
        self.criterions = {'criterion': self.criterion}

    def _train(self, data_loader, epoch):
        """Run one training epoch over ``data_loader``.

        For the labeled part the criterion returns one loss per filter stage.
        Stage 0 contributes its own loss; every later stage contributes only
        the positive part of ``loss[i] - loss[i-1].detach()``. All terms are
        divided by a detached normaliser (the same quantity computed on
        detached losses, plus 1e-9) before averaging. The additional part,
        when enabled, contributes ``10 * mean`` of the criterion's first loss
        on the last filter output.

        The lrer is stepped after every batch unless ``args.is_epoch_lrer`` is
        set, in which case it is stepped once after the epoch.
        """
        self.meters.reset()
        lbs = self.args.labeled_batch_size
        self.model.train()

        timer = time.time()
        for idx, (inp, gt) in enumerate(data_loader):
            inp, gt = self._batch_prehandle(inp, gt, True)

            self.optimizer.zero_grad()
            resulter, _ = self.model(inp)
            pred_outputs = tool.dict_value(resulter, 'outputs')

            # ---- labeled samples: one loss per filter stage ----
            l_pred = (func.split_tensor_tuple(pred_outputs, 0, lbs), )
            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)
            l_image_losses = self.criterion(l_pred, l_gt, l_inp)

            # Normaliser, computed entirely on detached values so that no
            # gradient flows through it.
            detached_losses = [stage_loss.detach() for stage_loss in l_image_losses]
            sum_losses = detached_losses[0]
            for previous, current in zip(detached_losses, detached_losses[1:]):
                increase = current - previous
                sum_losses = sum_losses + increase * (increase > 0).float()
            sum_losses = (sum_losses + 1e-9).detach()

            first_stage_loss = torch.mean(l_image_losses[0] / sum_losses)
            scaled_l_image_losses = [first_stage_loss]
            self.meters.update('fine_filter_0_loss', first_stage_loss.item())

            for i in range(1, len(l_image_losses)):
                stage_loss = (l_image_losses[i] - detached_losses[i - 1]) / sum_losses
                # Keep only the positive part of the difference.
                stage_loss = torch.mean(stage_loss * (stage_loss > 0).float())
                scaled_l_image_losses.append(stage_loss)
                self.meters.update('fine_filter_{0}_loss'.format(i), stage_loss.item())

            # ---- additional samples: only the last filter output is supervised ----
            u_image_loss = None
            if not self.args.ignore_additional:
                u_pred_outputs = func.split_tensor_tuple(pred_outputs, lbs, self.args.batch_size)
                u_pred = ((u_pred_outputs[-1], ), )

                u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
                u_gt = (u_gt[-1], )

                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                u_image_losses = self.criterion(u_pred, u_gt, u_inp)
                u_image_loss = torch.mean(u_image_losses[0]) * 10
                self.meters.update('coarse_filter_loss', u_image_loss.item())
            else:
                self.meters.update('coarse_filter_loss', 0.0)

            # ---- total loss ----
            loss = sum(scaled_l_image_losses)
            # Previously ``u_image_loss`` was added unconditionally, which
            # raised UnboundLocalError when ``ignore_additional`` was set.
            if u_image_loss is not None:
                loss = loss + u_image_loss

            loss.backward()
            self.optimizer.step()

            # ---- logging ----
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}'.format(
                    epoch + 1, idx, len(data_loader), meters=self.meters))
                # One line per stage actually produced, instead of a fixed six.
                for stage in range(len(scaled_l_image_losses)):
                    meter_key = 'fine_filter_{0}_loss'.format(stage)
                    logger.log_info('\tfine-filter-{0}-loss: {1:.6f}'.format(stage, self.meters[meter_key]))
                logger.log_info('\tcoarse-filter-loss: {meters[coarse_filter_loss]:.6f}'.format(meters=self.meters))

            # ---- visualisation of the first sample in the batch ----
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualization(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(pred_outputs, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            if not self.args.is_epoch_lrer:
                self.lrer.step()

            timer = time.time()

        if self.args.is_epoch_lrer:
            self.lrer.step()

    def _validate(self, data_loader, epoch):
        """Evaluate the model on ``data_loader`` and log loss and task metrics.

        Only the last filter output is scored. The per-batch loss is the
        average of the losses returned by the criterion; task metrics are
        accumulated into ``self.meters`` by ``task_func.metrics`` under the
        ``'IH'`` id and logged once at the end.
        """
        self.meters.reset()
        self.model.eval()

        timer = time.time()
        for idx, (inp, gt) in enumerate(data_loader):
            inp, gt = self._batch_prehandle(inp, gt, False)
            _, mask = inp

            # Nothing is back-propagated during validation, so skip autograd
            # bookkeeping for the forward pass and the loss.
            with torch.no_grad():
                resulter, _ = self.model(inp)
                pred_outputs = tool.dict_value(resulter, 'outputs')

                # NOTE(review): training hands the criterion a nested tuple
                # ((outputs...), ), whereas here it receives (last_output, )
                # -- confirm the criterion accepts both layouts.
                pred = (pred_outputs[-1], )
                gt = (gt[-1], )

                losses = self.criterion.forward(pred, gt, inp)
                loss = sum(losses) / len(losses)

            self.meters.update('loss', loss.item())

            self.task_func.metrics(pred_outputs[-1].detach(), gt[-1], mask, self.meters, id_str='IH')

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                'loss: {meters[loss]:.6f}\n'
                                .format(epoch + 1, idx, len(data_loader), meters=self.meters))

            if self.args.visualize:
                self._visualization(
                    epoch, idx, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple((pred_outputs[-1], ), 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            timer = time.time()

        # Collect every meter whose key contains the task's metric marker and
        # starts with a known id.
        metrics_info = {'IH': ''}
        for key in sorted(self.meters.keys()):
            if self.task_func.METRIC_STR not in key:
                continue
            for id_str in metrics_info:
                if key.startswith(id_str):
                    metrics_info[id_str] += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(metrics_info['IH'].replace('_', '-')))

    def _visualization(self, epoch, idx, is_train, inp, pred, gt):
        """Save the input, mask, predictions and targets of one sample as JPEGs.

        Files are written under ``args.visual_train_path`` or
        ``args.visual_val_path`` with the prefix ``'{epoch}_{idx}'``. Image
        tensors are transposed with axes ``(1, 2, 0)`` and scaled by 255;
        values are clipped to ``[0, 255]`` before the uint8 cast so that
        out-of-range values do not wrap around. Targets consisting entirely of
        the missing-target sentinel are skipped.
        """
        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        def to_uint8(array):
            # Scale to 0..255 and clip before the cast.
            return np.clip(array * 255, 0, 255).astype('uint8')

        x, mask = inp

        x = np.transpose(x.detach().cpu().numpy(), (1, 2, 0))
        Image.fromarray(to_uint8(x)).save(out_path + '_1_0_x.jpg')

        mask = mask[0].data.cpu().numpy()
        Image.fromarray(to_uint8(mask), mode='L').save(out_path + '_2_0_mask.jpg')

        # A dedicated loop index keeps the ``idx`` argument intact.
        for filter_idx, (pred_, gt_) in enumerate(zip(pred, gt), start=1):
            pred_ = np.transpose(pred_.detach().cpu().numpy(), (1, 2, 0))
            Image.fromarray(to_uint8(pred_)).save(out_path + '_1_{0}_pred_filter.jpg'.format(filter_idx))

            # Exact element-wise sentinel test instead of comparing a float mean.
            if not bool((gt_ == self._MISSING_GT).all()):
                gt_ = np.transpose(gt_.detach().cpu().numpy(), (1, 2, 0))
                Image.fromarray(to_uint8(gt_)).save(out_path + '_2_{0}_gt_filter.jpg'.format(filter_idx))

    def _save_checkpoint(self, epoch):
        """Write epoch, model, optimizer and lrer state to ``checkpoint_{epoch}.ckpt``.

        The file is created inside ``args.checkpoint_path``.
        """
        state = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lrer': self.lrer.state_dict(),
        }
        checkpoint = os.path.join(self.args.checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        """Restore model, optimizer and lrer state from ``args.resume``.

        Returns:
            The epoch stored in the checkpoint.
        """
        # NOTE: ``torch.load`` may unpickle arbitrary objects depending on the
        # PyTorch version's defaults; only resume from trusted checkpoints.
        checkpoint = torch.load(self.args.resume)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])
        return checkpoint['epoch']

    def _batch_prehandle(self, inp, gt, is_train):
        """Move a batch to the GPU and build per-filter inputs and targets.

        Training: the first ``args.labeled_batch_size`` samples are re-created
        by applying ``model.module.adjust`` with random arguments to the
        ground-truth image; the last result becomes the network input and the
        earlier results, in reverse order and followed by the ground-truth
        image, become the targets. Unless ``args.ignore_additional`` is set,
        the remaining samples keep their input, get the sentinel for every
        target except the last, and are combined with the labeled part.

        Validation: the input is unchanged; targets are the sentinel for every
        filter except the last, which is the ground-truth image.

        Returns:
            ``(inp, gt)`` where ``gt`` holds one entry per filter.
        """
        inp = tuple(tensor.cuda() for tensor in inp)
        gt = tuple(tensor.cuda() for tensor in gt)

        filter_num = len(self.model.module.model.filter_types)

        if not is_train:
            gt_image, = gt
            placeholder = torch.full_like(gt_image, self._MISSING_GT)
            gt = [placeholder] * (filter_num - 1) + [gt_image] if filter_num > 0 else []
            if filter_num == 0:
                # Mirror the original list assignment, which requires at
                # least one filter.
                raise IndexError('list assignment index out of range')
            return inp, gt

        lbs = self.args.labeled_batch_size

        # ---- labeled samples ----
        l_inp = func.split_tensor_tuple(inp, 0, lbs)
        l_gt = func.split_tensor_tuple(gt, 0, lbs)

        _, l_mask = l_inp
        l_gt_image, = l_gt

        n = l_gt_image.shape[0]
        l_rand_arguments = [self._rand_adjustment_values(n) for _ in range(filter_num)]

        l_x = self.model.module.adjust(l_gt_image, l_mask, l_rand_arguments)

        # The last adjusted image is the input; the earlier ones, reversed,
        # are the intermediate targets and the original image is the final one.
        l_inp = (l_x[-1], l_mask)
        l_gt = list(reversed(l_x[:-1]))
        l_gt.append(l_gt_image)

        if self.args.ignore_additional:
            return l_inp, l_gt

        # ---- additional samples ----
        u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)
        u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)

        u_gt_image, = u_gt
        placeholder = torch.full_like(u_gt_image, self._MISSING_GT)

        u_gt = [placeholder for _ in range(filter_num)]
        u_gt[-1] = u_gt_image

        inp = func.combine_tensor_tuple(l_inp, u_inp, 0)
        gt = func.combine_tensor_tuple(l_gt, u_gt, 0)
        return inp, gt

    def _rand_adjustment_values(self, n):
        """Return an ``(n, 1)`` CUDA float tensor sampled uniformly from [-1, 1)."""
        values = torch.FloatTensor(np.random.uniform(-1, 1, n))
        return values.view(n, 1).cuda()
|
|
|