import argparse
import datetime
import os
import traceback

import kornia
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

import models
from datasets import MEFDataset
from models import PSNR, SSIM, CosineLR
from tools import SingleSummaryWriter
from tools import saver, mutils


def get_args():
    parser = argparse.ArgumentParser('Breaking Down the Darkness')
    parser.add_argument('--num_gpus', type=int, default=1, help='number of gpus being used')
    parser.add_argument('--num_workers', type=int, default=12, help='num_workers of dataloader')
    parser.add_argument('--batch_size', type=int, default=1, help='the number of images per batch among all devices')
    parser.add_argument('-m', '--model', type=str, default='INet', help='model name')
    parser.add_argument('-ts', '--targets_split', type=str, default='targets', help='dir of targets')
    parser.add_argument('--comment', type=str, default='default', help='project comment')
    parser.add_argument('--graph', action='store_true')
    parser.add_argument('--scratch', action='store_true')
    parser.add_argument('--sampling', action='store_true')
    parser.add_argument('--test_on_start', action='store_true')

    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--no_sche', action='store_true', help='disable the LR scheduler')

    parser.add_argument('--optim', type=str, default='adam',
                        help="optimizer for training; 'adam' is suggested until the "
                             "very final stage, then switch to 'sgd'")
    parser.add_argument('--num_epochs', type=int, default=500)
    parser.add_argument('--val_interval', type=int, default=1, help='number of epochs between validation phases')
    parser.add_argument('--save_interval', type=int, default=500, help='number of steps between saving')
    parser.add_argument('--data_path', type=str, default='./data/SICE',
                        help='the root folder of the dataset')
    parser.add_argument('--log_path', type=str, default='logs/')
    parser.add_argument('--saved_path', type=str, default='logs/')
    parser.add_argument('-r', '--resize', type=int, default=-1, help='resize of training images')
    args = parser.parse_args()
    return args
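
# Example invocation (script name and values are illustrative, not from the
# repo's docs; train() below expects 'train' and 'eval' subfolders under --data_path):
#   python train.py -m INet --batch_size 1 --lr 0.01 --comment inet_color --data_path ./data/SICE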


def compute_gradient(img):
    # Forward differences over the last two axes: gradx is the vertical
    # (row-wise) gradient, grady the horizontal (column-wise) gradient.
    gradx = img[..., 1:, :] - img[..., :-1, :]
    grady = img[..., 1:] - img[..., :-1]
    return gradx, grady
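
# For a batch img of shape (B, C, H, W), for example:
#   gradx, grady = compute_gradient(img)  # shapes (B, C, H-1, W) and (B, C, H, W-1)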


class ModelINet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.color_loss = models.L1Loss()
        self.restor_loss = models.MSSSIML1Loss(channels=3)
        # Color adaption network: takes (Y_in, Cb_in, Cr_in, Y_gt) and predicts (Cb, Cr).
        self.model_canet = model(in_channels=4, out_channels=2)
        self.eps = 1e-2

    def load_weight(self, model, weight_pth):
        state_dict = torch.load(weight_pth)
        ret = model.load_state_dict(state_dict, strict=True)
        print(ret)

    def forward(self, image, image_gt, training=True):
        # Split both images into YCbCr planes: Y carries texture, Cb/Cr carry color.
        texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)
        texture_gt, cb_gt, cr_gt = torch.split(kornia.color.rgb_to_ycbcr(image_gt), 1, dim=1)

        colors = self.model_canet(torch.cat([texture_in, cb_in, cr_in, texture_gt], dim=1))
        cb, cr = torch.split(colors, 1, dim=1)

        color_loss1 = self.color_loss(cb, cb_gt)
        color_loss2 = self.color_loss(cr, cr_gt)

        # Recombine the ground-truth luminance with the predicted chrominance.
        image_out = kornia.color.ycbcr_to_rgb(torch.cat([texture_gt, cb, cr], dim=1))
        restor_loss = self.restor_loss(image_out, image_gt)

        psnr = PSNR(image_out, image_gt)
        ssim = SSIM(image_out, image_gt).item()
        return image_out, color_loss1, color_loss2, restor_loss, psnr, ssim
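
# A minimal usage sketch of the wrapper (input shapes are assumptions for
# illustration; the constructor call mirrors how train() builds it below):
#   net = ModelINet(models.INet)
#   rgb_in = torch.rand(1, 3, 256, 256)
#   rgb_gt = torch.rand(1, 3, 256, 256)
#   out, c1, c2, rl, psnr, ssim = net(rgb_in, rgb_gt, training=False)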


def train(opt):
    # Seed the CPU RNG unconditionally and the CUDA RNG when available.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    timestamp = mutils.get_formatted_time()
    opt.saved_path = opt.saved_path + f'/{opt.comment}/{timestamp}'
    opt.log_path = opt.log_path + f'/{opt.comment}/{timestamp}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    training_params = {'batch_size': opt.batch_size,
                       'shuffle': True,
                       'drop_last': True,
                       'num_workers': opt.num_workers}

    val_params = {'batch_size': 1,
                  'shuffle': False,
                  'drop_last': True,
                  'num_workers': opt.num_workers}

    training_set = MEFDataset(os.path.join(opt.data_path, 'train'))
    training_generator = DataLoader(training_set, **training_params)

    val_set = MEFDataset(os.path.join(opt.data_path, 'eval'))
    val_generator = DataLoader(val_set, **val_params)

    model = getattr(models, opt.model)
    model = ModelINet(model)
    print(model)

    writer = SingleSummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = nn.DataParallel(model)

    # Address the inner network through the unwrapped module so that the
    # optimizer and train()/eval() switches also work under DataParallel.
    core = model.module if isinstance(model, nn.DataParallel) else model

    if opt.optim == 'adam':
        optimizer = torch.optim.Adam(core.model_canet.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(core.model_canet.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = CosineLR(optimizer, opt.lr, opt.num_epochs)
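    # CosineLR is assumed here to implement standard cosine annealing,
    # lr_t = 0.5 * lr * (1 + cos(pi * t / num_epochs)); this describes the
    # helper from models only as an assumption, not verified behavior.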
    epoch = 0
    step = 0
    core.model_canet.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            if not opt.sampling and not opt.test_on_start:
                for iter, (data, target, name1, name2) in enumerate(progress_bar):
                    # Skip batches already consumed when resuming mid-epoch.
                    if iter < step - last_epoch * num_iter_per_epoch:
                        progress_bar.update()
                        continue
                    try:
                        if opt.num_gpus == 1:
                            data, target = data.cuda(), target.cuda()
                        optimizer.zero_grad()

                        image_out, color_loss1, color_loss2, \
                            restor_loss, psnr, ssim = model(data, target, training=True)
                        loss = color_loss1 + color_loss2 + restor_loss

                        loss.backward()
                        optimizer.step()

                        epoch_loss.append(float(loss))

                        progress_bar.set_description(
                            'Step: {}. Epoch: {}/{}. Iteration: {}/{}. '
                            'color_loss1: {:1.5f}, color_loss2: {:1.5f}, '
                            'restor_loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
                                step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch,
                                color_loss1.item(), color_loss2.item(),
                                restor_loss.item(), psnr, ssim))
                        writer.add_scalar('Loss/train', loss, step)
                        writer.add_scalar('PSNR/train', psnr, step)
                        writer.add_scalar('SSIM/train', ssim, step)

                        current_lr = optimizer.param_groups[0]['lr']
                        writer.add_scalar('learning_rate', current_lr, step)

                        step += 1

                    except Exception as e:
                        print('[Error]', traceback.format_exc())
                        print(e)
                        continue

            # Advance the schedule once per epoch unless disabled via --no_sche.
            if not opt.no_sche:
                scheduler.step()

            saver.base_url = os.path.join(opt.saved_path, 'results', '%03d' % epoch)

            if epoch % opt.val_interval == 0:
                core.model_canet.eval()
                loss_ls = []
                psnrs = []
                ssims = []

                for iter, (data, target, name1, name2) in enumerate(val_generator):
                    with torch.no_grad():
                        if opt.num_gpus == 1:
                            data, target = data.cuda(), target.cuda()

                        image_out, color_loss1, color_loss2, restor_loss, \
                            psnr, ssim = model(data, target, training=False)
                        saver.save_image(data, name=os.path.splitext(name1[0])[0] + '_im1')
                        saver.save_image(target, name=os.path.splitext(name2[0])[0] + '_im2')
                        saver.save_image(image_out, name=os.path.splitext(name2[0])[0] + '_im2_pred')

                        loss = restor_loss + color_loss1 + color_loss2
                        loss_ls.append(loss.item())
                        psnrs.append(psnr)
                        ssims.append(ssim)

                        # Evaluate the swapped pair as well: predict im1's colors from im2.
                        image_out, color_loss1, color_loss2, restor_loss, \
                            psnr, ssim = model(target, data, training=False)
                        saver.save_image(image_out, name=os.path.splitext(name1[0])[0] + '_im1_pred')

                        loss = restor_loss + color_loss1 + color_loss2
                        loss_ls.append(loss.item())
                        psnrs.append(psnr)
                        ssims.append(ssim)

                loss = np.mean(np.array(loss_ls))
                psnr = np.mean(np.array(psnrs))
                ssim = np.mean(np.array(ssims))

                print(
                    'Val. Epoch: {}/{}. Loss: {:1.5f}, psnr: {:.5f}, ssim: {:.5f}'.format(
                        epoch, opt.num_epochs, loss, psnr, ssim))
                writer.add_scalar('Loss/val', loss, step)
                writer.add_scalar('PSNR/val', psnr, step)
                writer.add_scalar('SSIM/val', ssim, step)

                save_checkpoint(model, f'{opt.model}_{epoch:03d}_{psnr}_{ssim}_{step}.pth')

                core.model_canet.train()

            opt.test_on_start = False
            if opt.sampling:
                exit(0)
    except KeyboardInterrupt:
        save_checkpoint(model, f'{opt.model}_{epoch}_{step}_keyboardInterrupt.pth')
    writer.close()


def save_checkpoint(model, name):
    # Note: relies on the module-level `opt` assigned in the __main__ block below.
    if isinstance(model, nn.DataParallel):
        torch.save(model.module.model_canet.state_dict(), os.path.join(opt.saved_path, name))
    else:
        torch.save(model.model_canet.state_dict(), os.path.join(opt.saved_path, name))
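
# These checkpoints hold only model_canet's state_dict, so they can be restored
# through ModelINet.load_weight, e.g. (hypothetical path shown for illustration):
#   wrapper = ModelINet(getattr(models, opt.model))
#   wrapper.load_weight(wrapper.model_canet, os.path.join(opt.saved_path, 'INet_000_..._0.pth'))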


if __name__ == '__main__':
    opt = get_args()
    train(opt)