| | |
| |
|
import os, sys
import torch
import glob
import time, shutil
import math
import gc
from tqdm import tqdm
from collections import defaultdict

from torch.multiprocessing import Pool, Process, set_start_method
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# CUDA tensors cannot be shared across fork()ed dataloader workers, so force
# the 'spawn' start method.  A RuntimeError means it was already set earlier
# in this process, which is fine.
try:
    set_start_method('spawn')
except RuntimeError:
    pass

# Make the repository root importable so the project-local packages
# (loss/, architecture/, scripts/) resolve when this file is run directly.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from loss.gan_loss import GANLoss, MultiScaleGANLoss
from loss.pixel_loss import PixelLoss, L1_Charbonnier_loss
from loss.perceptual_loss import PerceptualLoss
from loss.anime_perceptual_loss import Anime_PerceptualLoss
from architecture.dataset import ImageDataset
from scripts.generate_lr_esr import generate_low_res_esr

# Single module-level AMP gradient scaler shared by the generator and
# discriminator update steps in train_master.single_iteration().
scaler = torch.cuda.amp.GradScaler()
| |
|
class train_master(object):
    """Base training driver for the super-resolution networks.

    NOTE(review): this base class calls `self.loss_init()`, `self.call_model()`,
    `self.calculate_loss(...)` and `self.tensorboard_report(...)`, none of which
    are defined in this file — presumably subclasses provide them (and set
    `self.generator` / `self.discriminator`); confirm against the subclasses.
    """

    def __init__(self, options, args, model_name, has_discriminator=False) -> None:
        # options: dict of training hyper-parameters (learning rate, batch
        # size, loss weights, dataset paths, ...); args: parsed CLI arguments.
        self.args = args
        self.model_name = model_name
        self.options = options
        self.has_discriminator = has_discriminator

        # Subclass hook: builds the criterion objects (self.cri_*).
        self.loss_init()

        # Subclass hook: builds self.generator (and self.discriminator).
        self.call_model()

        # One Adam optimizer per network; both follow the same LR schedule
        # (see adjust_learning_rate).
        self.learning_rate = options['start_learning_rate']
        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.learning_rate, betas=(options["adam_beta1"], options["adam_beta2"]))
        if self.has_discriminator:
            self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.learning_rate, betas=(self.options["adam_beta1"], self.options["adam_beta2"]))

        # Resume bookkeeping: load_weight() overwrites these when resuming.
        self.start_iteration = 0
        self.lowest_generator_loss = float("inf")

        # TensorBoard writer plus a per-loss accumulator (defaults to 0).
        self.writer = SummaryWriter()
        self.weight_store = defaultdict(int)

        # Training-schedule / dataloader settings.
        self.n_iterations = options['train_iterations']
        self.batch_size = options['train_batch_size']
        self.n_cpu = options['train_dataloader_workers']
|
| |
|
| | def adjust_learning_rate(self, iteration_idx): |
| | self.learning_rate = self.options['start_learning_rate'] |
| | end_iteration = self.options['train_iterations'] |
| |
|
| | |
| | for idx in range(min(end_iteration, iteration_idx)//self.options['decay_iteration']): |
| | idx = idx+1 |
| | if idx * self.options['decay_iteration'] in self.options['double_milestones']: |
| | |
| | self.learning_rate = self.learning_rate * 2 |
| | else: |
| | |
| | self.learning_rate = self.learning_rate * self.options['decay_gamma'] |
| |
|
| | |
| | for param_group in self.optimizer_g.param_groups: |
| | param_group['lr'] = self.learning_rate |
| | |
| | if self.has_discriminator: |
| | |
| | for param_group in self.optimizer_d.param_groups: |
| | param_group['lr'] = self.learning_rate |
| |
|
| | assert(self.learning_rate == self.optimizer_g.param_groups[0]['lr']) |
| |
|
| |
|
| | def pixel_loss_load(self): |
| | if self.options['pixel_loss'] == "L1": |
| | self.cri_pix = PixelLoss().cuda() |
| | elif self.options['pixel_loss'] == "L1_Charbonnier": |
| | self.cri_pix = L1_Charbonnier_loss().cuda() |
| |
|
| | print("We are using {} loss".format(self.options['pixel_loss'])) |
| | |
| |
|
| | def GAN_loss_load(self): |
| | |
| | gan_loss_weight = self.options["gan_loss_weight"] |
| | vgg_type = self.options['train_perceptual_vgg_type'] |
| |
|
| | |
| | self.cri_pix = torch.nn.L1Loss().cuda() |
| | self.cri_vgg_perceptual = PerceptualLoss(self.options['train_perceptual_layer_weights'], vgg_type, perceptual_weight=self.options["vgg_perceptual_loss_weight"]).cuda() |
| | self.cri_danbooru_perceptual = Anime_PerceptualLoss(self.options["Danbooru_layer_weights"], perceptual_weight=self.options["danbooru_perceptual_loss_weight"]).cuda() |
| |
|
| | |
| | if self.options['discriminator_type'] == "PatchDiscriminator": |
| | self.cri_gan = MultiScaleGANLoss(gan_type="lsgan", loss_weight=gan_loss_weight).cuda() |
| | elif self.options['discriminator_type'] == "UNetDiscriminator": |
| | self.cri_gan = GANLoss(gan_type="vanilla", loss_weight=gan_loss_weight).cuda() |
| |
|
| | def tensorboard_epoch_draw(self, epoch_loss, epoch): |
| | self.writer.add_scalar('Loss/train-Loss-Epoch', epoch_loss, epoch) |
| |
|
| |
|
| | def master_run(self): |
| | torch.backends.cudnn.benchmark = True |
| | print("options are ", self.options) |
| |
|
| | |
| | self.generate_lr() |
| |
|
| | |
| | train_lr_paths = glob.glob(self.options["lr_dataset_path"] + "/*.*") |
| | degrade_hr_paths = glob.glob(self.options["degrade_hr_dataset_path"] + "/*.*") |
| | train_hr_paths = glob.glob(self.options["train_hr_dataset_path"] + "/*.*") |
| | train_dataloader = DataLoader(ImageDataset(train_lr_paths, degrade_hr_paths, train_hr_paths), batch_size=self.batch_size, shuffle=True, num_workers=self.n_cpu) |
| | dataset_length = len(os.listdir(self.options["train_hr_dataset_path"])) |
| |
|
| |
|
| | |
| | if self.args.auto_resume_best or self.args.auto_resume_closest: |
| | self.load_weight(self.model_name) |
| | elif self.args.pretrained_path != "": |
| | self.load_pretrained(self.model_name) |
| |
|
| | |
| | start_epoch = self.start_iteration // math.ceil(dataset_length / self.options['train_batch_size']) |
| | n_epochs = self.n_iterations // math.ceil(dataset_length / self.options['train_batch_size']) |
| | iteration_idx = self.start_iteration |
| | self.batch_idx = iteration_idx |
| | self.adjust_learning_rate(iteration_idx) |
| |
|
| | for epoch in range(start_epoch, n_epochs): |
| | print("This is epoch {} and the start iteration is {} with learning rate {}".format(epoch, iteration_idx, self.optimizer_g.param_groups[0]['lr'])) |
| |
|
| | |
| | if epoch != start_epoch and epoch % self.options['degradate_generation_freq'] == 0: |
| | self.generate_lr() |
| |
|
| | |
| | loss_per_epoch = 0.0 |
| | self.generator.train() |
| | tqdm_bar = tqdm(train_dataloader, total=len(train_dataloader)) |
| | for batch_idx, imgs in enumerate(tqdm_bar): |
| |
|
| | imgs_lr = imgs["lr"].cuda() |
| | imgs_degrade_hr = imgs["degrade_hr"].cuda() |
| | imgs_hr = imgs["hr"].cuda() |
| |
|
| | |
| | self.generator_loss = 0 |
| | self.single_iteration(imgs_lr, imgs_degrade_hr, imgs_hr) |
| | |
| | |
| | self.tensorboard_report(iteration_idx) |
| | loss_per_epoch += self.generator_loss.item() |
| | |
| | |
| | if self.lowest_generator_loss >= self.generator_loss.item(): |
| | self.lowest_generator_loss = self.generator_loss.item() |
| | print("\nSave model with the lowest generator_loss among all iteartions ", self.lowest_generator_loss) |
| |
|
| | |
| | self.save_weight(iteration_idx, self.model_name+"_best", self.options) |
| |
|
| | self.lowest_tensorboard_report(iteration_idx) |
| | |
| | |
| | iteration_idx += 1 |
| | self.batch_idx = iteration_idx |
| | if iteration_idx % self.options['decay_iteration'] == 0: |
| | self.adjust_learning_rate(iteration_idx) |
| | print("Update the learning rate to {} at iteration {} ".format(self.optimizer_g.param_groups[0]['lr'], iteration_idx)) |
| |
|
| | |
| | |
| | |
| | self.tensorboard_epoch_draw( loss_per_epoch/batch_idx, epoch) |
| | |
| |
|
| | |
| | self.save_weight(iteration_idx, self.model_name+"_closest", self.options) |
| | |
| | if epoch % self.options['checkpoints_freq'] == 0 or epoch == n_epochs-1: |
| | self.save_weight(iteration_idx, "checkpoints/" + self.model_name + "_epoch_" + str(epoch), self.options) |
| |
|
| |
|
| | |
| | torch.cuda.empty_cache() |
| | time.sleep(5) |
| | |
| |
|
| |
|
    def single_iteration(self, imgs_lr, imgs_degrade_hr, imgs_hr):
        """Run one mixed-precision training step: a generator update, then
        (when a discriminator is present) a discriminator update.

        Uses the module-level AMP `scaler`; self.generator_loss is filled in
        by the subclass-provided calculate_loss().
        """

        # ---- Generator step -------------------------------------------------
        self.optimizer_g.zero_grad()
        if self.has_discriminator:
            # Freeze D so the generator's backward pass does not accumulate
            # gradients into the discriminator.
            for p in self.discriminator.parameters():
                p.requires_grad = False

        with torch.cuda.amp.autocast():
            # Forward pass in half precision.
            gen_hr = self.generator(imgs_lr)

            # Subclass hook: accumulates all generator losses into
            # self.generator_loss.
            self.calculate_loss(gen_hr, imgs_hr)

        # Scaled backward + optimizer step through the shared AMP scaler.
        scaler.scale(self.generator_loss).backward()
        scaler.step(self.optimizer_g)
        scaler.update()

        # ---- Discriminator step ---------------------------------------------
        if self.has_discriminator:
            # Unfreeze D for its own update.
            for p in self.discriminator.parameters():
                p.requires_grad = True

            self.optimizer_d.zero_grad()

            # Real branch: D should classify the degraded HR image as real.
            with torch.cuda.amp.autocast():
                real_d_preds = self.discriminator(imgs_degrade_hr)
                l_d_real = self.cri_gan(real_d_preds, True, is_disc=True)
                scaler.scale(l_d_real).backward()

            # Fake branch: detach().clone() so no gradients flow back into G.
            with torch.cuda.amp.autocast():
                fake_d_preds = self.discriminator(gen_hr.detach().clone())
                l_d_fake = self.cri_gan(fake_d_preds, False, is_disc=True)
                scaler.scale(l_d_fake).backward()

            # NOTE(review): scaler.update() is called a second time within the
            # same iteration (once after the G step above, once here).  With a
            # single shared GradScaler the recommended pattern is one update()
            # per iteration after all step() calls — confirm this is intended.
            scaler.step(self.optimizer_d)
            scaler.update()
|
| | |
| | def load_pretrained(self, name): |
| | |
| |
|
| | weight_dir = self.args.pretrained_path |
| | if not os.path.exists(weight_dir): |
| | print("No such pretrained "+weight_dir+" file exists! We end the program! Please check the dir!") |
| | os._exit(0) |
| | |
| | checkpoint_g = torch.load(weight_dir) |
| | if 'model_state_dict' in checkpoint_g: |
| | self.generator.load_state_dict(checkpoint_g['model_state_dict']) |
| | elif 'params_ema' in checkpoint_g: |
| | self.generator.load_state_dict(checkpoint_g['params_ema']) |
| | else: |
| | raise NotImplementedError("We didn't cannot locate the weight of thie pretrained weight") |
| | |
| | print(f"We will use pretrained "+name+" weight!") |
| | |
| |
|
| | def load_weight(self, head_prefix): |
| | |
| | head = head_prefix+"_best" if self.args.auto_resume_best else head_prefix+"_closest" |
| | |
| | if os.path.exists("saved_models/"+head+"_generator.pth"): |
| | print("We need to resume previous " + head + " weight") |
| |
|
| | |
| | checkpoint_g = torch.load("saved_models/"+head+"_generator.pth") |
| | self.generator.load_state_dict(checkpoint_g['model_state_dict']) |
| | self.optimizer_g.load_state_dict(checkpoint_g['optimizer_state_dict']) |
| |
|
| | |
| | if self.has_discriminator: |
| | checkpoint_d = torch.load("saved_models/"+head+"_discriminator.pth") |
| | self.discriminator.load_state_dict(checkpoint_d['model_state_dict']) |
| | self.optimizer_d.load_state_dict(checkpoint_d['optimizer_state_dict']) |
| | assert(checkpoint_g['iteration'] == checkpoint_d['iteration']) |
| |
|
| | self.start_iteration = checkpoint_g['iteration'] + 1 |
| | |
| | |
| | if os.path.exists("saved_models/" + head_prefix + "_best_generator.pth"): |
| | checkpoint_g = torch.load("saved_models/" + head_prefix + "_best_generator.pth") |
| | else: |
| | print("There is no best weight exists!") |
| | self.lowest_generator_loss = min(self.lowest_generator_loss, checkpoint_g["lowest_generator_weight"] ) |
| | print("The lowest generator loss at the beginning is ", self.lowest_generator_loss) |
| | else: |
| | print(f"No saved_models/"+head+"_generator.pth " or " saved_models/"+head+"_discriminator.pth exists") |
| |
|
| |
|
| | print(f"We will start from the iteration {self.start_iteration}") |
| |
|
| |
|
| |
|
| | def save_weight(self, iteration, name, opt): |
| |
|
| | |
| | torch.save({ |
| | 'iteration': iteration, |
| | 'model_state_dict': self.generator.state_dict(), |
| | 'optimizer_state_dict': self.optimizer_g.state_dict(), |
| | 'lowest_generator_weight': self.lowest_generator_loss, |
| | 'opt': opt, |
| | }, "saved_models/" + name + "_generator.pth") |
| | |
| | |
| | |
| |
|
| |
|
| | if self.has_discriminator: |
| | |
| | torch.save({ |
| | 'iteration': iteration, |
| | 'model_state_dict': self.discriminator.state_dict(), |
| | 'optimizer_state_dict': self.optimizer_d.state_dict(), |
| | }, "saved_models/" + name + "_discriminator.pth") |
| |
|
| |
|
| | def lowest_tensorboard_report(self, iteration): |
| | self.writer.add_scalar('Loss/lowest-weight', self.generator_loss, iteration) |
| |
|
| |
|
| | @torch.no_grad() |
| | def generate_lr(self): |
| |
|
| | |
| | os.system("python scripts/generate_lr_esr.py") |
| |
|
| |
|
| | |
| | lr_paths = os.listdir(self.options["lr_dataset_path"]) |
| | degrade_hr_paths = os.listdir(self.options["degrade_hr_dataset_path"]) |
| | hr_paths = os.listdir(self.options["train_hr_dataset_path"]) |
| | |
| | assert(len(lr_paths) == len(degrade_hr_paths)) |
| | assert(len(lr_paths) == len(hr_paths)) |
| |
|
| |
|
| | |
| |
|
| | |
| |
|