| |
| |
| |
|
|
| import os |
| import sys |
| import math |
| import time |
| import lpips |
| import random |
| import datetime |
| import functools |
| import numpy as np |
| from pathlib import Path |
| from loguru import logger |
| from copy import deepcopy |
| from omegaconf import OmegaConf |
| from collections import OrderedDict |
| from einops import rearrange |
|
|
| from datapipe.datasets import create_dataset |
| from models.resample import UniformSampler |
|
|
| import torch |
| import torch.nn as nn |
| import torch.cuda.amp as amp |
| import torch.nn.functional as F |
| import torch.utils.data as udata |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| import torchvision.utils as vutils |
| from torch.utils.tensorboard import SummaryWriter |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| from utils import util_net |
| from utils import util_common |
| from utils import util_image |
|
|
| from basicsr.utils import DiffJPEG |
| from basicsr.utils.img_process_util import filter2D |
| from basicsr.data.transforms import paired_random_crop |
| from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt |
|
|
| class TrainerBase: |
| def __init__(self, configs): |
| self.configs = configs |
|
|
| |
| self.setup_dist() |
|
|
| |
| self.setup_seed() |
|
|
| |
| self.init_logger() |
|
|
| |
| if self.rank == 0: self.logger.info(OmegaConf.to_yaml(self.configs)) |
|
|
| |
| self.build_model() |
|
|
| |
| self.setup_optimizaton() |
|
|
| |
| self.resume_from_ckpt() |
|
|
    def setup_dist(self):
        """Configure GPU visibility and initialize NCCL distributed training.

        Sets self.num_gpus and self.rank. The rank is read from the
        LOCAL_RANK environment variable, so multi-GPU runs must be launched
        with torchrun (or another launcher that sets LOCAL_RANK).
        """
        if self.configs.gpu_id:
            # Restrict CUDA to the configured devices. NOTE(review): assumes
            # gpu_id indexes to strings (e.g. the string "01" or a list of
            # str); a list of ints would break the join -- confirm config type.
            gpu_id = self.configs.gpu_id
            num_gpus = len(gpu_id)
            os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([gpu_id[ii] for ii in range(num_gpus)])
        else:
            num_gpus = torch.cuda.device_count()

        if num_gpus > 1:
            if mp.get_start_method(allow_none=True) is None:
                mp.set_start_method('spawn')
            rank = int(os.environ['LOCAL_RANK'])
            # map each local rank onto one visible device
            torch.cuda.set_device(rank % num_gpus)
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
            )

        self.num_gpus = num_gpus
        self.rank = int(os.environ['LOCAL_RANK']) if num_gpus > 1 else 0
|
|
| def setup_seed(self, seed=None): |
| seed = self.configs.seed if seed is None else seed |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
|
|
    def init_logger(self):
        """Create the saving folders, the loguru text logger and the
        TensorBoard writer; only rank 0 touches the filesystem.

        Rank-0 side effects: creates save_dir, save_dir/logs, save_dir/ckpts.
        If save_dir (or training.log) already exists, configs.resume must be
        truthy -- this guards against silently clobbering a previous run.
        """
        save_dir = Path(self.configs.save_dir)
        logtxet_path = save_dir / 'training.log'   # (sic) text-log file
        log_dir = save_dir / 'logs'                # tensorboard events
        ckpt_dir = save_dir / 'ckpts'              # checkpoints
        self.ckpt_dir = ckpt_dir
        if self.rank == 0:
            if not save_dir.exists():
                save_dir.mkdir()
            else:
                assert self.configs.resume, '''Please check the resume parameter. If you do not
                                               want to resume from some checkpoint, please delete
                                               the saving folder first.'''

            # text logging: replace loguru's default sink with file + stderr
            if logtxet_path.exists():
                assert self.configs.resume
            self.logger = logger
            self.logger.remove()
            self.logger.add(logtxet_path, format="{message}", mode='a')
            self.logger.add(sys.stderr, format="{message}")

            # tensorboard logging; per-phase step counters for scalars/images
            if not log_dir.exists():
                log_dir.mkdir()
            self.writer = SummaryWriter(str(log_dir))
            self.log_step = {phase: 1 for phase in ['train', 'val']}
            self.log_step_img = {phase: 1 for phase in ['train', 'val']}

            # checkpoint folder
            if not ckpt_dir.exists():
                ckpt_dir.mkdir()
|
|
| def close_logger(self): |
| if self.rank == 0: self.writer.close() |
|
|
    def resume_from_ckpt(self):
        """Restore model weights and training progress from a checkpoint.

        configs.resume is either True (pick the checkpoint with the largest
        iteration number under self.ckpt_dir) or an explicit path string.
        Also replays the lr schedule up to the restored iteration and
        re-seeds the RNGs with the iteration count so the resumed data
        stream differs from a fresh start.
        """
        if self.configs.resume:
            if type(self.configs.resume) == bool:
                # locate model_<N>.pth with the largest N
                ckpt_index = max([int(x.stem.split('_')[1]) for x in Path(self.ckpt_dir).glob('*.pth')])
                ckpt_path = str(Path(self.ckpt_dir) / f"model_{ckpt_index}.pth")
            else:
                ckpt_path = self.configs.resume
            assert os.path.isfile(ckpt_path)
            if self.rank == 0:
                self.logger.info(f"=> Loaded checkpoint {ckpt_path}")
            ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
            util_net.reload_model(self.model, ckpt['state_dict'])
            torch.cuda.empty_cache()

            # restore the iteration counter and fast-forward the lr schedule
            self.iters_start = ckpt['iters_start']
            for ii in range(self.iters_start): self.adjust_lr(ii)
            if self.rank == 0:
                self.log_step = ckpt['log_step']
                self.log_step_img = ckpt['log_step_img']

            # re-seed so resumed training does not repeat the initial order
            self.setup_seed(self.iters_start)
        else:
            self.iters_start = 0
|
|
| def setup_optimizaton(self): |
| self.optimizer = torch.optim.AdamW(self.model.parameters(), |
| lr=self.configs.train.lr, |
| weight_decay=self.configs.train.weight_decay) |
|
|
| def build_model(self): |
| params = self.configs.model.get('params', dict) |
| model = util_common.get_obj_from_str(self.configs.model.target)(**params) |
| if self.num_gpus > 1: |
| self.model = DDP(model.cuda(), device_ids=[self.rank,]) |
| else: |
| self.model = model.cuda() |
|
|
| |
| if self.rank == 0: |
| self.lpips_loss = lpips.LPIPS(net='vgg').cuda() |
|
|
| |
| self.print_model_info() |
|
|
| def build_dataloader(self): |
| def _wrap_loader(loader): |
| while True: yield from loader |
|
|
| datasets = {} |
| for phase in ['train', ]: |
| dataset_config = self.configs.data.get(phase, dict) |
| datasets[phase] = create_dataset(dataset_config) |
|
|
| dataloaders = {} |
| |
| if self.rank == 0: |
| for phase in ['train',]: |
| length = len(datasets[phase]) |
| self.logger.info('Number of images in {:s} data set: {:d}'.format(phase, length)) |
| if self.num_gpus > 1: |
| shuffle = False |
| sampler = udata.distributed.DistributedSampler(datasets['train'], |
| num_replicas=self.num_gpus, |
| rank=self.rank) |
| else: |
| shuffle = True |
| sampler = None |
| dataloaders['train'] = _wrap_loader(udata.DataLoader( |
| datasets['train'], |
| batch_size=self.configs.train.batch[0] // self.num_gpus, |
| shuffle=shuffle, |
| drop_last=False, |
| num_workers=self.configs.train.num_workers // self.num_gpus, |
| pin_memory=True, |
| prefetch_factor=self.configs.train.prefetch_factor, |
| worker_init_fn=my_worker_init_fn, |
| sampler=sampler)) |
|
|
| self.datasets = datasets |
| self.dataloaders = dataloaders |
| self.sampler = sampler |
|
|
| def print_model_info(self): |
| if self.rank == 0: |
| num_params = util_net.calculate_parameters(self.model) / 1000**2 |
| self.logger.info("Detailed network architecture:") |
| self.logger.info(self.model.__repr__()) |
| self.logger.info(f"Number of parameters: {num_params:.2f}M") |
|
|
| def prepare_data(self, phase='train'): |
| pass |
|
|
| def validation(self): |
| pass |
|
|
    def train(self):
        """Main optimization loop.

        Runs configs.train.iterations steps, interleaving validation
        (val_freq), lr adjustment, checkpointing (save_freq) and -- when a
        DistributedSampler is active -- per-epoch reshuffling.
        """
        self.build_dataloader()

        self.model.train()
        num_iters_epoch = math.ceil(len(self.datasets['train']) / self.configs.train.batch[0])
        for ii in range(self.iters_start, self.configs.train.iterations):
            self.current_iters = ii + 1

            # fetch one batch; the flag routes prepare_data through the
            # RealESRGAN on-the-fly degradation pipeline when configured
            data = self.prepare_data(
                next(self.dataloaders['train']),
                self.configs.data.train.type.lower() == 'realesrgan',
            )

            # one optimization step (subclass-defined)
            self.training_step(data)

            # periodic validation, rank 0 only
            if (ii+1) % self.configs.train.val_freq == 0 and 'val' in self.dataloaders:
                if self.rank==0:
                    self.validation()

            # learning-rate schedule
            self.adjust_lr()

            # periodic checkpointing, rank 0 only
            if (ii+1) % self.configs.train.save_freq == 0 and self.rank == 0:
                self.save_ckpt()

            # epoch boundary: reshuffle the distributed sampler
            if (ii+1) % num_iters_epoch == 0 and not self.sampler is None:
                self.sampler.set_epoch(ii+1)

        if self.rank == 0:
            self.close_logger()
|
|
| def training_step(self, data): |
| pass |
|
|
| def adjust_lr(self): |
| if hasattr(self, 'lr_sheduler'): |
| self.lr_sheduler.step() |
|
|
| def save_ckpt(self): |
| ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters) |
| torch.save({'iters_start': self.current_iters, |
| 'log_step': {phase:self.log_step[phase] for phase in ['train', 'val']}, |
| 'log_step_img': {phase:self.log_step_img[phase] for phase in ['train', 'val']}, |
| 'state_dict': self.model.state_dict()}, ckpt_path) |
|
|
| class TrainerSR(TrainerBase): |
| def __init__(self, configs): |
| super().__init__(configs) |
|
|
| def mse_loss(self, pred, target): |
| return F.mse_loss(pred, target, reduction='mean') |
|
|
    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        Operates in-place on self.lq / self.gt; keeps a fixed-size GPU pool
        whose size must be an exact multiple of the batch size.
        """
        # lazily size and allocate the pool from the first batch
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_size'):
            self.queue_size = self.configs.data.train.params.get('queue_size', b*50)
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:
            # pool is full: shuffle it, then swap its first b pairs with the
            # current batch so the returned batch mixes old degradations
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # take the first b samples out of the pool
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # and store the fresh batch in their place
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # pool not yet full: only enqueue, pass the batch through as-is
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b
|
|
    @torch.no_grad()
    def prepare_data(self, data, real_esrgan=True):
        """Build (lq, gt) training pairs.

        When real_esrgan is True, synthesize the low-quality input on the
        fly with the two-stage RealESRGAN degradation pipeline
        (blur -> random resize -> noise -> JPEG, twice, plus a final
        sinc-filter/JPEG ordering coin flip); otherwise just move the batch
        to the GPU.

        Expects data to carry 'gt', 'kernel1', 'kernel2' and 'sinc_kernel'
        on the RealESRGAN path. Returns {'lq', 'gt'} tensors on the GPU.
        """
        if real_esrgan:
            if not hasattr(self, 'jpeger'):
                self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG artifacts

            im_gt = data['gt'].cuda()
            kernel1 = data['kernel1'].cuda()
            kernel2 = data['kernel2'].cuda()
            sinc_kernel = data['sinc_kernel'].cuda()

            ori_h, ori_w = im_gt.size()[2:4]

            # ----------------------- first degradation stage -------------- #
            # blur
            out = filter2D(im_gt, kernel1)
            # random resize (up / down / keep)
            updown_type = random.choices(
                ['up', 'down', 'keep'],
                self.configs.degradation['resize_prob'],
            )[0]
            if updown_type == 'up':
                scale = random.uniform(1, self.configs.degradation['resize_range'][1])
            elif updown_type == 'down':
                scale = random.uniform(self.configs.degradation['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise (Gaussian or Poisson, optionally grayscale)
            gray_noise_prob = self.configs.degradation['gray_noise_prob']
            if random.random() < self.configs.degradation['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out,
                    sigma_range=self.configs.degradation['noise_range'],
                    clip=True,
                    rounds=False,
                    gray_prob=gray_noise_prob,
                )
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.configs.degradation['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression with a per-sample random quality
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp: the jpeger needs [0, 1] input
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- second degradation stage ------------- #
            # blur (applied only with second_blur_prob)
            if random.random() < self.configs.degradation['second_blur_prob']:
                out = filter2D(out, kernel2)
            # random resize, this time targeting the final lq size scaled by sf
            updown_type = random.choices(
                ['up', 'down', 'keep'],
                self.configs.degradation['resize_prob2'],
            )[0]
            if updown_type == 'up':
                scale = random.uniform(1, self.configs.degradation['resize_range2'][1])
            elif updown_type == 'down':
                scale = random.uniform(self.configs.degradation['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out,
                size=(int(ori_h / self.configs.model.params.sf * scale),
                      int(ori_w / self.configs.model.params.sf * scale)),
                mode=mode,
            )
            # add noise
            gray_noise_prob = self.configs.degradation['gray_noise_prob2']
            if random.random() < self.configs.degradation['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out,
                    sigma_range=self.configs.degradation['noise_range2'],
                    clip=True,
                    rounds=False,
                    gray_prob=gray_noise_prob,
                )
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.configs.degradation['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False,
                )

            # Final step: randomly order [resize back + sinc filter] and
            # [JPEG compression] -- the two orders produce different ringing
            # artifacts, so a coin flip covers both.
            if random.random() < 0.5:
                # resize back to gt/sf + sinc filter, then JPEG
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(
                    out,
                    size=(ori_h // self.configs.model.params.sf,
                          ori_w // self.configs.model.params.sf),
                    mode=mode,
                )
                out = filter2D(out, sinc_kernel)
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG first, then resize back + sinc filter
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(
                    out,
                    size=(ori_h // self.configs.model.params.sf,
                          ori_w // self.configs.model.params.sf),
                    mode=mode,
                )
                out = filter2D(out, sinc_kernel)

            # quantize to 8-bit levels like a real saved image
            im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # aligned random crop of the pair
            gt_size = self.configs.degradation['gt_size']
            im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.configs.model.params.sf)
            self.lq, self.gt = im_lq, im_gt

            # mix with the training-pair pool for extra per-batch diversity
            self._dequeue_and_enqueue()
            # contiguous() is required by some subsequent ops (e.g. DDP views)
            self.lq = self.lq.contiguous()

            return {'lq':self.lq, 'gt':self.gt}
        else:
            return {key:value.cuda() for key, value in data.items()}
|
|
| def setup_optimizaton(self): |
| super().setup_optimizaton() |
| self.lr_sheduler = torch.optim.lr_scheduler.CosineAnnealingLR( |
| self.optimizer, |
| T_max = self.configs.train.iterations, |
| eta_min=self.configs.train.lr_min, |
| ) |
|
|
    def training_step(self, data):
        """One optimization step with gradient accumulation over micro-batches.

        NOTE(review): self.loss_fun is not assigned anywhere in this file --
        presumably set elsewhere (mse_loss looks like the intended default);
        confirm before using this trainer standalone.
        """
        current_batchsize = data['lq'].shape[0]
        micro_batchsize = self.configs.train.microbatch
        num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)

        self.optimizer.zero_grad()
        for jj in range(0, current_batchsize, micro_batchsize):
            micro_data = {key:value[jj:jj+micro_batchsize,] for key, value in data.items()}
            last_batch = (jj+micro_batchsize >= current_batchsize)
            hq_pred = self.model(micro_data['lq'])
            # under DDP, skip gradient synchronization on all but the last
            # micro-batch so gradients are all-reduced only once per step
            if last_batch or self.num_gpus <= 1:
                loss = self.loss_fun(hq_pred, micro_data['gt']) / hq_pred.shape[0]
            else:
                with self.model.no_sync():
                    loss = self.loss_fun(hq_pred, micro_data['gt']) / hq_pred.shape[0]
            # scale so that accumulated gradients average over micro-batches
            loss /= num_grad_accumulate
            loss.backward()

            # flush scalar/image logs only on the last micro-batch
            self.log_step_train(hq_pred, loss, micro_data, flag=last_batch)

        self.optimizer.step()
|
|
    def log_step_train(self, hq_pred, loss, batch, flag=False, phase='train'):
        '''
        Accumulate the running loss and periodically emit scalar / image
        logs (rank 0 only).

        param loss: loss value of the current micro-batch
        param flag: True only on the last micro-batch of an optimization
                    step, so window flushes happen once per step
        '''
        if self.rank == 0:
            chn = batch['lq'].shape[1]
            # restart the running mean at the start of each logging window
            if self.current_iters % self.configs.train.log_freq[0] == 1:
                self.loss_mean = 0

            self.loss_mean += loss.item()

            if self.current_iters % self.configs.train.log_freq[0] == 0 and flag:
                self.loss_mean /= self.configs.train.log_freq[0]
                # approximate per-pixel MSE. NOTE(review): assumes the loss
                # value is a batch-summed squared error -- verify against
                # self.loss_fun, which is defined outside this file.
                mse_pixel = self.loss_mean / batch['gt'].numel() * batch['gt'].shape[0]
                # iteration counters are logged in units of 100 steps
                log_str = 'Train:{:05d}/{:05d}, Loss:{:.2e}, MSE:{:.2e}, lr:{:.2e}'.format(
                    self.current_iters // 100,
                    self.configs.train.iterations // 100,
                    self.loss_mean,
                    mse_pixel,
                    self.optimizer.param_groups[0]['lr']
                )
                self.logger.info(log_str)
                # tensorboard scalar
                self.writer.add_scalar(f'Loss-Train', self.loss_mean, self.log_step[phase])
                self.log_step[phase] += 1
            if self.current_iters % self.configs.train.log_freq[1] == 0 and flag:
                # image triplet: input / target / prediction
                x1 = vutils.make_grid(batch['lq'], normalize=True, scale_each=True)
                self.writer.add_image("Train LQ Image", x1, self.log_step_img[phase])
                x2 = vutils.make_grid(batch['gt'], normalize=True, scale_each=True)
                self.writer.add_image("Train HQ Image", x2, self.log_step_img[phase])
                x3 = vutils.make_grid(hq_pred.detach().data, normalize=True, scale_each=True)
                self.writer.add_image("Train Recovered Image", x3, self.log_step_img[phase])
                self.log_step_img[phase] += 1

            # wall-clock timing of one save window
            if self.current_iters % self.configs.train.save_freq == 1 and flag:
                self.tic = time.time()
            if self.current_iters % self.configs.train.save_freq == 0 and flag:
                self.toc = time.time()
                elaplsed = (self.toc - self.tic)
                self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
                self.logger.info("="*60)
|
|
| def validation(self, phase='val'): |
| if self.rank == 0: |
| self.model.eval() |
| psnr_mean = lpips_mean = 0 |
| total_iters = math.ceil(len(self.datasets[phase]) / self.configs.train.batch[1]) |
| for ii, data in enumerate(self.dataloaders[phase]): |
| data = self.prepare_data(data) |
| with torch.no_grad(): |
| hq_pred = self.model(data['lq']) |
| hq_pred.clamp_(0.0, 1.0) |
| lpips = self.lpips_loss( |
| util_image.normalize_th(hq_pred, reverse=False), |
| util_image.normalize_th(data['gt'], reverse=False), |
| ).sum().item() |
| psnr = util_image.batch_PSNR( |
| hq_pred, |
| data['gt'], |
| ycbcr=True |
| ) |
|
|
| psnr_mean += psnr |
| lpips_mean += lpips |
|
|
| if (ii+1) % self.configs.train.log_freq[2] == 0: |
| log_str = '{:s}:{:03d}/{:03d}, PSNR={:5.2f}, LPIPS={:6.4f}'.format( |
| phase, |
| ii+1, |
| total_iters, |
| psnr / hq_pred.shape[0], |
| lpips / hq_pred.shape[0] |
| ) |
| self.logger.info(log_str) |
| x1 = vutils.make_grid(data['lq'], normalize=True, scale_each=True) |
| self.writer.add_image("Validation LQ Image", x1, self.log_step_img[phase]) |
| x2 = vutils.make_grid(data['gt'], normalize=True, scale_each=True) |
| self.writer.add_image("Validation HQ Image", x2, self.log_step_img[phase]) |
| x3 = vutils.make_grid(hq_pred.detach().data, normalize=True, scale_each=True) |
| self.writer.add_image("Validation Recovered Image", x3, self.log_step_img[phase]) |
| self.log_step_img[phase] += 1 |
|
|
| psnr_mean /= len(self.datasets[phase]) |
| lpips_mean /= len(self.datasets[phase]) |
| |
| self.writer.add_scalar('Validation PSRN', psnr_mean, self.log_step[phase]) |
| self.writer.add_scalar('Validation LPIPS', lpips_mean, self.log_step[phase]) |
| self.log_step[phase] += 1 |
| |
| self.logger.info(f'PSNR={psnr_mean:5.2f}, LPIPS={lpips_mean:6.4f}') |
| self.logger.info("="*60) |
|
|
| self.model.train() |
|
|
| def build_dataloader(self): |
| super().build_dataloader() |
| if self.rank == 0 and 'val' in self.configs.data: |
| dataset_config = self.configs.data.get('val', dict) |
| self.datasets['val'] = create_dataset(dataset_config) |
| self.dataloaders['val'] = udata.DataLoader( |
| self.datasets['val'], |
| batch_size=self.configs.train.batch[1], |
| shuffle=False, |
| drop_last=False, |
| num_workers=0, |
| pin_memory=True, |
| ) |
|
|
| class TrainerDiffusionFace(TrainerBase): |
| def __init__(self, configs): |
| |
| self.ema_rates = OmegaConf.to_object(configs.train.ema_rates) |
| super().__init__(configs) |
|
|
| def init_logger(self): |
| super().init_logger() |
|
|
| save_dir = Path(self.configs.save_dir) |
| ema_ckpt_dir = save_dir / 'ema_ckpts' |
| if self.rank == 0: |
| if not ema_ckpt_dir.exists(): |
| util_common.mkdir(ema_ckpt_dir, delete=False, parents=False) |
| else: |
| if not self.configs.resume: |
| util_common.mkdir(ema_ckpt_dir, delete=True, parents=False) |
|
|
| self.ema_ckpt_dir = ema_ckpt_dir |
|
|
    def resume_from_ckpt(self):
        """Resume base training state, then reload the matching EMA states.

        EMA files are expected at
        ema_ckpt_dir/ema0<rate*1000>_model_<N>.pth, mirroring the name of
        the main checkpoint selected by the base class.
        """
        super().resume_from_ckpt()

        def _load_ema_state(ema_state, ckpt):
            # deep-copy each tensor so the EMA buffers do not alias the
            # freshly loaded checkpoint
            for key in ema_state.keys():
                ema_state[key] = deepcopy(ckpt[key].detach().data)

        if self.configs.resume:
            # re-derive the checkpoint path exactly as the base class did
            if type(self.configs.resume) == bool:
                ckpt_index = max([int(x.stem.split('_')[1]) for x in Path(self.ckpt_dir).glob('*.pth')])
                ckpt_path = str(Path(self.ckpt_dir) / f"model_{ckpt_index}.pth")
            else:
                ckpt_path = self.configs.resume
            assert os.path.isfile(ckpt_path)
            # load one EMA snapshot per decay rate
            for rate in self.ema_rates:
                ema_ckpt_path = self.ema_ckpt_dir / (f"ema0{int(rate*1000)}_"+Path(ckpt_path).name)
                ema_ckpt = torch.load(ema_ckpt_path, map_location=f"cuda:{self.rank}")
                _load_ema_state(self.ema_state[f"0{int(rate*1000)}"], ema_ckpt)
|
|
| def build_model(self): |
| params = self.configs.model.get('params', dict) |
| model = util_common.get_obj_from_str(self.configs.model.target)(**params) |
| self.ema_model = deepcopy(model.cuda()) |
| if self.num_gpus > 1: |
| self.model = DDP(model.cuda(), device_ids=[self.rank,]) |
| else: |
| self.model = model.cuda() |
|
|
| self.ema_state = {} |
| for rate in self.ema_rates: |
| self.ema_state[f"0{int(rate*1000)}"] = OrderedDict( |
| {key:deepcopy(value.data) for key, value in self.model.state_dict().items()} |
| ) |
|
|
| |
| self.print_model_info() |
|
|
| params = self.configs.diffusion.get('params', dict) |
| self.base_diffusion = util_common.get_obj_from_str(self.configs.diffusion.target)(**params) |
| self.sample_scheduler_diffusion = UniformSampler(self.base_diffusion.num_timesteps) |
|
|
| def prepare_data(self, data, realesrgan=False): |
| data = {key:value.cuda() for key, value in data.items()} |
| return data |
|
|
| def training_step(self, data): |
| current_batchsize = data['image'].shape[0] |
| micro_batchsize = self.configs.train.microbatch |
| num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize) |
|
|
| if self.configs.train.use_fp16: |
| scaler = amp.GradScaler() |
|
|
| self.optimizer.zero_grad() |
| for jj in range(0, current_batchsize, micro_batchsize): |
| micro_data = {key:value[jj:jj+micro_batchsize,] for key, value in data.items()} |
| last_batch = (jj+micro_batchsize >= current_batchsize) |
| tt, weights = self.sample_scheduler_diffusion.sample( |
| micro_data['image'].shape[0], |
| device=f"cuda:{self.rank}", |
| use_fp16=self.configs.train.use_fp16 |
| ) |
| compute_losses = functools.partial( |
| self.base_diffusion.training_losses, |
| self.model, |
| micro_data['image'], |
| tt, |
| model_kwargs={'y':micro_data['label']} if 'label' in micro_data else None, |
| ) |
| if self.configs.train.use_fp16: |
| with amp.autocast(): |
| if last_batch or self.num_gpus <= 1: |
| losses = compute_losses() |
| else: |
| with self.model.no_sync(): |
| losses = compute_losses() |
| loss = (losses["loss"] * weights).mean() / num_grad_accumulate |
| scaler.scale(loss).backward() |
| else: |
| if last_batch or self.num_gpus <= 1: |
| losses = compute_losses() |
| else: |
| with self.model.no_sync(): |
| losses = compute_losses() |
| loss = (losses["loss"] * weights).mean() / num_grad_accumulate |
| loss.backward() |
|
|
| |
| self.log_step_train(losses, tt, micro_data, last_batch) |
|
|
| if self.configs.train.use_fp16: |
| scaler.step(self.optimizer) |
| scaler.update() |
| else: |
| self.optimizer.step() |
|
|
| self.update_ema_model() |
|
|
| def update_ema_model(self): |
| if self.num_gpus > 1: |
| dist.barrier() |
| if self.rank == 0: |
| for rate in self.ema_rates: |
| ema_state = self.ema_state[f"0{int(rate*1000)}"] |
| source_state = self.model.state_dict() |
| for key, value in ema_state.items(): |
| ema_state[key].mul_(rate).add_(source_state[key].detach().data, alpha=1-rate) |
|
|
| def adjust_lr(self, ii): |
| base_lr = self.configs.train.lr |
| linear_steps = self.configs.train.milestones[0] |
| if ii <= linear_steps: |
| for params_group in self.optimizer.param_groups: |
| params_group['lr'] = (ii / linear_steps) * base_lr |
| elif ii in self.configs.train.milestones: |
| for params_group in self.optimizer.param_groups: |
| params_group['lr'] *= 0.5 |
|
|
    def log_step_train(self, loss, tt, batch, flag=False, phase='train'):
        '''
        Accumulate per-timestep loss statistics and periodically log them
        (rank 0 only).

        param loss: a dict recording the loss informations
        param tt: 1-D tensor, time steps
        param flag: True only on the last micro-batch of a step
        '''
        if self.rank == 0:
            chn = batch['image'].shape[1]
            num_timesteps = self.base_diffusion.num_timesteps
            # reset the per-timestep accumulators at each window start
            if self.current_iters % self.configs.train.log_freq[0] == 1:
                self.loss_mean = {key:torch.zeros(size=(num_timesteps,), dtype=torch.float64)
                                  for key in loss.keys()}
                self.loss_count = torch.zeros(size=(num_timesteps,), dtype=torch.float64)
            for key, value in loss.items():
                self.loss_mean[key][tt, ] += value.detach().data.cpu()
            self.loss_count[tt,] += 1

            if self.current_iters % self.configs.train.log_freq[0] == 0 and flag:
                # avoid division by zero for timesteps never sampled
                if torch.any(self.loss_count == 0):
                    self.loss_count += 1e-4
                for key, value in loss.items():
                    self.loss_mean[key] /= self.loss_count
                # counters reported in units of 100 iterations
                log_str = 'Train: {:05d}/{:05d}, Loss: '.format(
                    self.current_iters // 100,
                    self.configs.train.iterations // 100)
                # report the loss at an early / middle / final timestep
                for kk in [1, num_timesteps // 2, num_timesteps]:
                    if 'vb' in self.loss_mean:
                        # hybrid objective: total / mse / variational bound
                        log_str += 't({:d}):{:.2e}/{:.2e}/{:.2e}, '.format(
                            kk,
                            self.loss_mean['loss'][kk-1].item(),
                            self.loss_mean['mse'][kk-1].item(),
                            self.loss_mean['vb'][kk-1].item(),
                        )
                    else:
                        log_str += 't({:d}):{:.2e}, '.format(kk, self.loss_mean['loss'][kk-1].item())
                log_str += 'lr:{:.2e}'.format(self.optimizer.param_groups[0]['lr'])
                self.logger.info(log_str)
                # tensorboard scalars, one per probed timestep
                for kk in [1, num_timesteps // 2, num_timesteps]:
                    self.writer.add_scalar(f'Loss-Step-{kk}',
                                           self.loss_mean['loss'][kk-1].item(),
                                           self.log_step[phase])
                self.log_step[phase] += 1
            if self.current_iters % self.configs.train.log_freq[1] == 0 and flag:
                x1 = vutils.make_grid(batch['image'], normalize=True, scale_each=True)
                self.writer.add_image("Training Image", x1, self.log_step_img[phase])
                self.log_step_img[phase] += 1

            # wall-clock timing of one save window
            if self.current_iters % self.configs.train.save_freq == 1 and flag:
                self.tic = time.time()
            if self.current_iters % self.configs.train.save_freq == 0 and flag:
                self.toc = time.time()
                # NOTE(review): the num_timesteps scaling of elapsed time
                # has no visible grounding here -- confirm intent.
                elaplsed = (self.toc - self.tic) * num_timesteps / (num_timesteps - 1)
                self.logger.info(f"Elapsed time: {elaplsed:.2f}s")
                self.logger.info("="*130)
|
|
    def validation(self, phase='val'):
        """Sample a batch from the EMA model with the progressive diffusion
        sampler and log intermediate denoising states to TensorBoard.

        Keeps the first sample plus the states at 25/50/75/100% of the
        timesteps and lays them out one row per generated image.
        """
        # run validation with the first EMA rate's weights
        self.reload_ema_model(self.ema_rates[0])
        self.ema_model.eval()

        indices = [int(self.base_diffusion.num_timesteps * x) for x in [0.25, 0.5, 0.75, 1]]
        chn = 3   # assumes RGB output -- TODO confirm against the model config
        batch_size = self.configs.train.batch[1]
        shape = (batch_size, chn,) + (self.configs.data.train.params.out_size,) * 2
        num_iters = 0
        for sample in self.base_diffusion.p_sample_loop_progressive(
                model = self.ema_model,
                shape = shape,
                noise = None,
                clip_denoised = True,
                model_kwargs = None,
                device = f"cuda:{self.rank}",
                progress=False
                ):
            num_iters += 1
            # map from the model range back to [0, 1] for visualization
            img = util_image.normalize_th(sample['sample'], reverse=True)
            if num_iters == 1:
                im_recover = img
            elif num_iters in indices:
                im_recover_last = img
                # stack snapshots along the channel axis; unpacked below
                im_recover = torch.cat((im_recover, im_recover_last), dim=1)
        # unstack snapshots into one row of k images per sample
        im_recover = rearrange(im_recover, 'b (k c) h w -> (b k) c h w', c=chn)
        x1 = vutils.make_grid(im_recover, nrow=len(indices)+1, normalize=False)
        self.writer.add_image('Validation Sample', x1, self.log_step_img[phase])
        self.log_step_img[phase] += 1
|
|
| def save_ckpt(self): |
| if self.rank == 0: |
| ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters) |
| torch.save({'iters_start': self.current_iters, |
| 'log_step': {phase:self.log_step[phase] for phase in ['train', 'val']}, |
| 'log_step_img': {phase:self.log_step_img[phase] for phase in ['train', 'val']}, |
| 'state_dict': self.model.state_dict()}, ckpt_path) |
| for rate in self.ema_rates: |
| ema_ckpt_path = self.ema_ckpt_dir / (f"ema0{int(rate*1000)}_"+ckpt_path.name) |
| torch.save(self.ema_state[f"0{int(rate*1000)}"], ema_ckpt_path) |
|
|
| def calculate_lpips(self, inputs, targets): |
| inputs, targets = [(x-0.5)/0.5 for x in [inputs, targets]] |
| with torch.no_grad(): |
| mean_lpips = self.lpips_loss(inputs, targets) |
| return mean_lpips.mean().item() |
|
|
| def reload_ema_model(self, rate): |
| model_state = {key[7:]:value for key, value in self.ema_state[f"0{int(rate*1000)}"].items()} |
| self.ema_model.load_state_dict(model_state) |
|
|
def my_worker_init_fn(worker_id):
    """Give each DataLoader worker a distinct, reproducible NumPy seed
    derived from the parent process's current RNG state."""
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed(base_seed + worker_id)
|
|
if __name__ == '__main__':
    # Ad-hoc visual sanity check (not part of training): blend two test
    # images with increasing amounts of Gaussian noise and display the
    # result as an image grid.
    from utils import util_image
    from einops import rearrange
    im1 = util_image.imread('./testdata/inpainting/val/places/Places365_val_00012685_crop000.png',
                            chn = 'rgb', dtype='float32')
    im2 = util_image.imread('./testdata/inpainting/val/places/Places365_val_00014886_crop000.png',
                            chn = 'rgb', dtype='float32')
    # stack to (batch=2, channel, height, width)
    im = rearrange(np.stack((im1, im2), 3), 'h w c b -> b c h w')
    im_grid = im.copy()
    # prepend noisier and noisier versions along the channel axis
    for alpha in [0.8, 0.4, 0.1, 0]:
        im_new = im * alpha + np.random.randn(*im.shape) * (1 - alpha)
        im_grid = np.concatenate((im_new, im_grid), 1)

    im_grid = np.clip(im_grid, 0.0, 1.0)
    # unstack the 5 noise levels into separate grid cells
    im_grid = rearrange(im_grid, 'b (k c) h w -> (b k) c h w', k=5)
    xx = vutils.make_grid(torch.from_numpy(im_grid), nrow=5, normalize=True, scale_each=True).numpy()
    util_image.imshow(np.concatenate((im1, im2), 0))
    util_image.imshow(xx.transpose((1,2,0)))
|
|
|
|