| import os |
| import glob |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from tensorboardX import SummaryWriter |
|
|
| from backend.inpaint.sttn.auto_sttn import Discriminator |
| from backend.inpaint.sttn.auto_sttn import InpaintGenerator |
| from backend.tools.train.dataset_sttn import Dataset |
| from backend.tools.train.loss_sttn import AdversarialLoss |
|
|
|
|
class Trainer:
    """Adversarial trainer for the STTN video-inpainting model.

    Wires together the dataset/loader, the generator (``netG``) and
    discriminator (``netD``), losses, and optimizers described by
    ``config``, and drives the GAN training loop with optional
    DistributedDataParallel support.
    """

    def __init__(self, config, debug=False):
        """Build all training components from ``config``.

        Args:
            config (dict): nested configuration with at least the keys
                ``data_loader``, ``trainer``, ``losses``, ``device``,
                ``save_dir``, ``distributed``, ``world_size``,
                ``global_rank`` and ``local_rank``.
            debug (bool): if True, shrink save/valid/iteration counts so
                a full train/save cycle runs almost immediately.
        """
        self.config = config
        self.epoch = 0
        self.iteration = 0
        if debug:
            # Tiny limits so every code path (save, valid, stop) is hit fast.
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # Training dataset and (optionally distributed) data loader.
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)
        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank']
            )
        self.train_loader = DataLoader(
            self.train_dataset,
            # The configured batch size is global; split it across workers.
            batch_size=self.train_args['batch_size'] // config['world_size'],
            # Shuffle only when no sampler; a DistributedSampler shuffles itself.
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler
        )

        # Losses: adversarial (GAN) + pixel-wise L1.
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # Generator / discriminator pair. Hinge loss operates on raw
        # logits, so the discriminator applies a sigmoid only otherwise.
        self.netG = InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        self.netD = Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
        )
        self.netD = self.netD.to(self.config['device'])

        self.optimG = torch.optim.Adam(
            self.netG.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),
            lr=config['trainer']['lr'],
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        # Resume from the latest checkpoint (if any) BEFORE the DDP wrap,
        # so state dicts are loaded into the bare (unwrapped) modules.
        self.load()

        if config['distributed']:
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )

        # TensorBoard writers; only rank 0 logs in distributed mode.
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis')
            )
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen')
            )

    def get_lr(self):
        """Return the current generator learning rate."""
        return self.optimG.param_groups[0]['lr']

    def adjust_learning_rate(self):
        """Step-decay both optimizers' learning rates.

        The rate decays by 10x every ``niter`` iterations, frozen once
        ``niter_steady`` iterations have been reached.
        """
        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        # Only touch the optimizers when the rate actually changes.
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    def add_summary(self, writer, name, val):
        """Accumulate a scalar and flush its 100-iteration mean to TensorBoard.

        Args:
            writer: SummaryWriter to log to, or None to only accumulate.
            name (str): scalar tag, e.g. ``'loss/hole_loss'``.
            val (float): value for the current iteration.
        """
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        # Every 100 iterations, write the running mean and reset.
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    def load(self):
        """Resume models/optimizers from the newest checkpoint in save_dir.

        The epoch tag is taken from ``latest.ckpt`` when present,
        otherwise from the lexicographically last ``*.pth`` file (names
        are zero-padded, so string order matches numeric order). Does
        nothing (with a warning on rank 0) when no checkpoint exists.
        """
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # Close the file deterministically instead of leaking the handle.
            with open(os.path.join(model_path, 'latest.ckpt'), 'r') as f:
                latest_epoch = f.read().splitlines()[-1]
        else:
            ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                os.path.join(model_path, '*.pth'))]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            # Training counters are stored alongside the optimizer states.
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. An initialized model will be used.')

    def save(self, it):
        """Write gen/dis/opt checkpoints tagged ``it`` (rank 0 only).

        Args:
            it (int): checkpoint index, zero-padded to 5 digits in the
                file names and recorded in ``latest.ckpt``.
        """
        if self.config['global_rank'] == 0:
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))

            print('\nsaving model to {} ...'.format(gen_path))

            # Unwrap DataParallel/DDP so the saved keys have no 'module.' prefix.
            if isinstance(self.netG, (torch.nn.DataParallel, DDP)):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD

            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save({
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'optimD': self.optimD.state_dict()
            }, opt_path)

            # Record the newest tag directly instead of shelling out to
            # `echo` via os.system (portable, and errors are not swallowed).
            with open(os.path.join(self.config['save_dir'], 'latest.ckpt'), 'w') as f:
                f.write(str(it).zfill(5) + '\n')

    def train(self):
        """Run epochs until the configured iteration count is exceeded."""
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        while True:
            self.epoch += 1
            if self.config['distributed']:
                # Re-seed the sampler so shuffling differs per epoch.
                self.train_sampler.set_epoch(self.epoch)

            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    def _train_epoch(self, pbar):
        """Run one pass over the loader: alternate D and G updates.

        Args:
            pbar: tqdm bar on rank 0, a plain range elsewhere.
        """
        device = self.config['device']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1

            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # Zero out masked regions; the generator must fill them in.
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # Flatten the time axis to match the generator output
            # (assumed (b*t, c, h, w) — consistent with the views below).
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # Composite: real pixels outside the mask, predictions inside.
            comp_img = frames * (1. - masks) + masks * pred_img

            gen_loss = 0
            dis_loss = 0

            # --- Discriminator update (generator output detached). ---
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # --- Generator adversarial loss (gradients flow to netG). ---
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # L1 inside the hole, normalized by the mask area.
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # L1 outside the hole, normalized by the valid area.
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # Periodic checkpoint; the tag counts completed save intervals.
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
|
|
|
|