| | import os |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import wandb |
| | from torch.utils.data import DataLoader |
| | from torchvision import transforms |
| | from tqdm import tqdm |
| |
|
| | from data.LQGT_dataset import LQGTDataset, LQGTValDataset |
| | from model import decoder, discriminator, encoder |
| | from opt.option import args |
| | from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair, |
| | VGG19PerceptualLoss) |
| |
|
| | from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
| |
|
# Start a Weights & Biases run for this experiment; all parsed CLI args are
# recorded as the run's config so every hyper-parameter is tracked.
wandb.init(project='SR', config=args)
| |
|
| |
|
| |
|
| | |
# Restrict CUDA to the user-selected GPU. This must happen before any CUDA
# context is created, which is why it precedes torch.device('cuda') below.
if args.gpu_id is not None:
    # Bug fix: this was hardcoded to "0", silently ignoring the requested ID.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('using GPU %s' % args.gpu_id)
else:
    print('use --gpu_id to specify GPU ID to use')
    exit()
| |
|
# All models, losses, and metrics are moved to this device.
device = torch.device('cuda')

# Create the snapshot/checkpoint directory. makedirs also creates missing
# parent directories and, with exist_ok=True, is safe if the path already
# exists (os.mkdir would raise on a missing parent or a concurrent create).
os.makedirs(args.snap_path, exist_ok=True)
| |
|
print("Loading dataset...")

# Training pairs: random crop plus flip/rotate augmentation, then tensor conversion.
_train_transform = transforms.Compose(
    [RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()]
)
train_dataset = LQGTDataset(db_path=args.dir_data, transform=_train_transform)

# Validation pairs: paired crop only — no augmentation, so metrics are comparable.
_val_transform = transforms.Compose(
    [RandCrop_pair(args.patch_size, args.scale), ToTensor()]
)
val_dataset = LQGTValDataset(db_path=args.dir_data, transform=_val_transform)

# Shuffled training loader; drop_last keeps every batch at full batch_size.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    shuffle=True,
)

# Deterministic-order validation loader.
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
)
| |
|
| |
|
print("Create model")
# Discriminators: one on encoder features, one on LR-sized images, one on
# HR-sized images (patch scaled up by args.scale).
model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)

# Generator side: shared encoder, identity decoder (LR -> LR) and SR decoder (LR -> HR).
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)

# Let wandb track gradients/parameters of every model.
wandb.watch(model_Disc_feat)
wandb.watch(model_Disc_img_LR)
# Fix: the HR-image discriminator was the only model not being watched.
wandb.watch(model_Disc_img_HR)
wandb.watch(model_Enc)
wandb.watch(model_Dec_Id)
wandb.watch(model_Dec_SR)
| |
|
| |
|
print("Define Loss")
# Pixel-wise reconstruction loss (used for rec / res / idt / cyc terms below).
loss_L1 = nn.L1Loss().to(device)
# MSE doubles as the least-squares GAN criterion for all adversarial terms.
loss_MSE = nn.MSELoss().to(device)
# NOTE(review): loss_adversarial is defined but never used in the visible
# training loop (loss_MSE is used for adversarial terms instead) — confirm.
loss_adversarial = nn.BCEWithLogitsLoss().to(device)
# VGG19 feature-space perceptual loss for the SR outputs.
loss_percept = VGG19PerceptualLoss().to(device)
| |
|
| |
|
print("Define Optimizer")

# Adam settings shared by the generator- and discriminator-side optimizers.
_adam_kwargs = dict(
    betas=(args.beta1, args.beta2),
    weight_decay=args.weight_decay,
    amsgrad=True,
)

# Generator side: encoder and both decoders are updated jointly.
params_G = (
    list(model_Enc.parameters())
    + list(model_Dec_Id.parameters())
    + list(model_Dec_SR.parameters())
)
optimizer_G = optim.Adam(params_G, lr=args.lr_G, **_adam_kwargs)

# Discriminator side: all three critics share a single optimizer.
params_D = (
    list(model_Disc_feat.parameters())
    + list(model_Disc_img_LR.parameters())
    + list(model_Disc_img_HR.parameters())
)
optimizer_D = optim.Adam(params_D, lr=args.lr_D, **_adam_kwargs)
| |
|
print("Define Scheduler")

# Halve the learning rate at each milestone. The milestones appear to be
# iteration counts rather than epochs, since scheduler_G/scheduler_D are
# stepped inside the training loop below.
iter_indices = [args.interval1, args.interval2, args.interval3]
scheduler_G = optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=iter_indices, gamma=0.5)
scheduler_D = optim.lr_scheduler.MultiStepLR(optimizer_D, milestones=iter_indices, gamma=0.5)
| |
|
| | |
# Wrap the generator-side models for multi-GPU execution.
# NOTE(review): only one GPU is exposed via CUDA_VISIBLE_DEVICES above, so
# DataParallel is effectively a pass-through here; it does, however, prefix all
# state-dict keys with 'module.', which matters for checkpoint loading below.
model_Enc = nn.DataParallel(model_Enc)
model_Dec_Id = nn.DataParallel(model_Dec_Id)
model_Dec_SR = nn.DataParallel(model_Dec_SR)
| |
|
| | |
| | |
| | |
| | |
| |
|
print("Load model weight")

# Resume from a full training checkpoint (models + optimizers + schedulers).
if args.checkpoint is not None:
    # map_location lets a checkpoint saved on a different device be restored here.
    checkpoint = torch.load(args.checkpoint, map_location=device)

    # The generator-side models were saved while DataParallel-wrapped, so their
    # keys carry the 'module.' prefix and load directly into the wrapped models.
    model_Enc.load_state_dict(checkpoint['model_Enc'])
    model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
    model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
    model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])

    optimizer_D.load_state_dict(checkpoint['optimizer_D'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])

    scheduler_D.load_state_dict(checkpoint['scheduler_D'])
    scheduler_G.load_state_dict(checkpoint['scheduler_G'])

    # Off-by-one fix: the checkpoint stores the epoch it was saved at (after that
    # epoch finished), so training must resume at the NEXT epoch, not repeat it.
    start_epoch = checkpoint['epoch'] + 1
else:
    start_epoch = 0


# Optionally initialize the SR decoder from a pretrained RRDB network.
if args.pretrained is not None:
    ckpt = torch.load(args.pretrained, map_location=device)
    # The pretrained first conv expects 3 input channels; take one input slice
    # and replicate it across the decoder's 64 input feature channels.
    # NOTE(review): the 64s assume n_hidden_feats == 64 — confirm against args.
    ckpt["params"]["conv_first.weight"] = ckpt["params"]["conv_first.weight"][:, 0, :, :].expand(64, 64, 3, 3)
    # model_Dec_SR is DataParallel-wrapped, but the raw 'params' keys carry no
    # 'module.' prefix — load into the underlying module to avoid a key mismatch.
    model_Dec_SR.module.load_state_dict(ckpt["params"])
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
# Validation metrics: higher PSNR/SSIM is better; lower LPIPS is better.
PSNR = PeakSignalNoiseRatio().to(device)
SSIM = StructuralSimilarityIndexMeasure().to(device)
LPIPS = LearnedPerceptualImagePatchSimilarity().to(device)
| |
|
# ---------------------------------------------------------------------------
# Main entry: adversarial training (args.phase == "train") or validation-only.
# ---------------------------------------------------------------------------
if args.phase == "train":
    for epoch in range(start_epoch, args.epochs):
        # Generator-side networks in training mode (re-enabled each epoch after
        # the eval() switch in the per-epoch validation pass below).
        model_Enc.train()
        model_Dec_Id.train()
        model_Dec_SR.train()

        # All three discriminators in training mode.
        model_Disc_feat.train()
        model_Disc_img_LR.train()
        model_Disc_img_HR.train()

        # Running sums for logging; divided by `iter` when reported, so the
        # logged values are running means over the epoch so far.
        running_loss_D_total = 0.0
        running_loss_G_total = 0.0

        running_loss_align = 0.0
        running_loss_rec = 0.0
        running_loss_res = 0.0
        running_loss_sty = 0.0
        running_loss_idt = 0.0
        running_loss_cyc = 0.0

        # NOTE(review): `iter` shadows the Python builtin of the same name.
        iter = 0

        for data in tqdm(train_loader):
            iter += 1

            # X_t: real-world LR image (target domain, unpaired);
            # Y_s: clean HR ground truth (source domain).
            X_t, Y_s = data['img_LQ'], data['img_GT']

            # Build the paired clean LR input X_s by bicubic-downscaling Y_s.
            # NOTE(review): this Upsample module is re-created every iteration;
            # it is stateless and could be hoisted outside the loop.
            ds4 = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')
            X_s = ds4(Y_s)

            X_t = X_t.cuda(non_blocking=True)
            X_s = X_s.cuda(non_blocking=True)
            Y_s = Y_s.cuda(non_blocking=True)

            # Per-sample least-squares-GAN targets: 1 = real, 0 = fake.
            batch_size = X_t.size(0)
            real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
            fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)

            # ------------------- discriminator phase -------------------
            model_Disc_feat.zero_grad()
            model_Disc_img_LR.zero_grad()
            model_Disc_img_HR.zero_grad()

            # NOTE(review): zero_grad() is only called once before this loop, so
            # when args.n_disc > 1 gradients accumulate across inner iterations
            # between optimizer_D.step() calls — confirm this is intended.
            for i in range(args.n_disc):
                # Encoder features of both domains; detached below so only the
                # discriminators receive gradients in this phase.
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)

                # Feature critic: source-domain features labeled real,
                # target-domain features labeled fake.
                output_Disc_F_t = model_Disc_feat(F_t.detach())
                output_Disc_F_s = model_Disc_feat(F_s.detach())

                loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
                loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
                loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2

                # HR-image critic on the reconstruction path: Y_s_s = SR(Enc(X_s)).
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)

                loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2

                # LR-image critic on the style path: X_s_t = Id(Enc(X_s)) should
                # be distinguishable from a real target-domain LR image X_t.
                X_s_t = model_Dec_Id(F_s)

                output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
                output_Disc_X_t = model_Disc_img_LR(X_t)

                loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
                loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
                loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2

                # HR-image critic on the cycle path: SR(Enc(Id(Enc(X_s)))).
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))

                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
                # NOTE(review): this repeats the model_Disc_img_HR(Y_s) forward
                # pass computed just above; the earlier result could be reused.
                output_Disc_Y_s = model_Disc_img_HR(Y_s)

                loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2

                # Combined discriminator objective (equal weighting).
                loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty + loss_Disc_img_cyc
                loss_D_total.backward()
                optimizer_D.step()

            # Discriminator LR schedule advances once per data iteration, so the
            # MultiStepLR milestones are iteration counts, not epochs.
            scheduler_D.step()

            # ------------------- generator phase -------------------
            model_Enc.zero_grad()
            model_Dec_Id.zero_grad()
            model_Dec_SR.zero_grad()

            # NOTE(review): as in the discriminator loop, gradients accumulate
            # across inner iterations when args.n_gen > 1 (no re-zeroing).
            for i in range(args.n_gen):
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)

                # Alignment: push both domains' critic outputs toward the
                # midpoint (0.5) so the encoder makes features indistinguishable.
                output_Disc_F_t = model_Disc_feat(F_t)
                output_Disc_F_s = model_Disc_feat(F_s)

                loss_G_F_t = loss_MSE(output_Disc_F_t, (real_label + fake_label)/2)
                loss_G_F_s = loss_MSE(output_Disc_F_s, (real_label + fake_label)/2)
                L_align_E = loss_G_F_t + loss_G_F_s

                # Reconstruction: SR(Enc(X_s)) should match the HR ground truth
                # in L1, perceptual, and adversarial senses.
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)

                loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)

                loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)

                loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
                L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s

                # Restoration: the identity decoder must reproduce the
                # target-domain input it was fed.
                X_t_t = model_Dec_Id(F_t)
                L_res_G_t = loss_L1(X_t, X_t_t)

                # Style: Id(Enc(X_s)) should fool the LR-image critic, i.e. look
                # like a real target-domain LR image.
                X_s_t = model_Dec_Id(F_s)

                output_Disc_X_s_t = model_Disc_img_LR(X_s_t)

                loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
                L_sty_G_t = loss_G_X_s_t

                # Identity: features should survive a Dec_Id -> Enc round trip.
                F_s_tilda = model_Enc(model_Dec_Id(F_s))
                L_idt_G_t = loss_L1(F_s, F_s_tilda)

                # Cycle: SR of the round-tripped features should still match the
                # HR ground truth (L1 + perceptual + adversarial).
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))

                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)

                loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)

                loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)

                loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
                L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s

                # Weighted sum of all generator objectives.
                loss_G_total = args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t + args.lambda_idt*L_idt_G_t + args.lambda_cyc*L_cyc_G_t_G_SR
                loss_G_total.backward()
                optimizer_G.step()
                # NOTE(review): scheduler_G steps inside this inner loop while
                # scheduler_D steps once per data iteration — the asymmetry only
                # matters when args.n_gen > 1; confirm it is intended.
                scheduler_G.step()

            # Accumulate the losses from the last inner iteration for logging.
            running_loss_D_total += loss_D_total.item()
            running_loss_G_total += loss_G_total.item()

            running_loss_align += L_align_E.item()
            running_loss_rec += L_rec_G_SR.item()
            running_loss_res += L_res_G_t.item()
            running_loss_sty += L_sty_G_t.item()
            running_loss_idt += L_idt_G_t.item()
            running_loss_cyc += L_cyc_G_t_G_SR.item()
            # Step-level logging of running means every log_interval iterations.
            if iter % args.log_interval == 0:
                wandb.log(
                    {
                        "loss_D_total_step": running_loss_D_total/iter,
                        "loss_G_total_step": running_loss_G_total/iter,
                        "loss_align_step": running_loss_align/iter,
                        "loss_rec_step": running_loss_rec/iter,
                        "loss_res_step": running_loss_res/iter,
                        "loss_sty_step": running_loss_sty/iter,
                        "loss_idt_step": running_loss_idt/iter,
                        "loss_cyc_step": running_loss_cyc/iter,
                    }
                )

        # ------------------- per-epoch validation -------------------
        total_PSNR = 0
        total_SSIM = 0
        total_LPIPS = 0
        val_iter = 0
        with torch.no_grad():
            # Only the SR inference path (Enc -> Dec_SR) is evaluated; both are
            # switched back to train() at the top of the next epoch.
            model_Enc.eval()
            model_Dec_SR.eval()
            for batch_idx, batch in enumerate(val_loader):
                val_iter += 1
                source = batch["img_LQ"].to(device)
                target = batch["img_GT"].to(device)

                feat = model_Enc(source)
                out = model_Dec_SR(feat)

                total_PSNR += PSNR(out, target)
                total_SSIM += SSIM(out, target)
                total_LPIPS += LPIPS(out, target)
        # Epoch-level log: mean losses over the epoch plus validation metrics.
        wandb.log(
            {
                "epoch": epoch,
                "lr": optimizer_G.param_groups[0]['lr'],
                "loss_D_total_epoch": running_loss_D_total/iter,
                "loss_G_total_epoch": running_loss_G_total/iter,
                "loss_align_epoch": running_loss_align/iter,
                "loss_rec_epoch": running_loss_rec/iter,
                "loss_res_epoch": running_loss_res/iter,
                "loss_sty_epoch": running_loss_sty/iter,
                "loss_idt_epoch": running_loss_idt/iter,
                "loss_cyc_epoch": running_loss_cyc/iter,
                "PSNR_val": total_PSNR/val_iter,
                "SSIM_val": total_SSIM/val_iter,
                "LPIPS_val": total_LPIPS/val_iter
            }
        )

        # Periodic checkpoint with everything needed to resume training.
        if (epoch+1) % args.save_freq == 0:
            weights_file_name = 'epoch_%d.pth' % (epoch+1)
            weights_file = os.path.join(args.snap_path, weights_file_name)
            torch.save({
                'epoch': epoch,

                # Generator-side models are DataParallel-wrapped, so these keys
                # carry the 'module.' prefix.
                'model_Enc': model_Enc.state_dict(),
                'model_Dec_Id': model_Dec_Id.state_dict(),
                'model_Dec_SR': model_Dec_SR.state_dict(),
                'model_Disc_feat': model_Disc_feat.state_dict(),
                'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
                'model_Disc_img_HR': model_Disc_img_HR.state_dict(),

                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),

                'scheduler_D': scheduler_D.state_dict(),
                'scheduler_G': scheduler_G.state_dict(),
            }, weights_file)
            print('save weights of epoch %d' % (epoch+1))
else:
    # ------------------- validation-only phase -------------------
    # Same metric loop as the per-epoch validation above, printed to stdout.
    total_PSNR = 0
    total_SSIM = 0
    total_LPIPS = 0
    val_iter = 0
    with torch.no_grad():
        model_Enc.eval()
        model_Dec_SR.eval()
        for batch_idx, batch in enumerate(val_loader):
            val_iter += 1
            source = batch["img_LQ"].to(device)
            target = batch["img_GT"].to(device)

            feat = model_Enc(source)
            out = model_Dec_SR(feat)

            total_PSNR += PSNR(out, target)
            total_SSIM += SSIM(out, target)
            total_LPIPS += LPIPS(out, target)
    print("PSNR_val: ", total_PSNR/val_iter)
    print("SSIM_val: ", total_SSIM/val_iter)
    print("LPIPS_val: ", total_LPIPS/val_iter)