import math
import sys
from collections import OrderedDict

sys.path.append('..')
import lpips
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
                                     VectorQuantizer, VectorQuantizerTexture)
from models.losses.segmentation_loss import BCELossWithQuant
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
                                      calculate_adaptive_weight, hinge_d_loss)

class VQModel():

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizer(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt['z_channels'], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt['z_channels'],
                                               1).to(self.device)
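
    # A minimal sketch of the expected `opt` dict (illustrative values only;
    # the real options come from the experiment config, and keys such as
    # 'ch_mult' or 'attn_resolutions' depend on the target resolution):
    #
    #     opt = {
    #         'ch': 128, 'num_res_blocks': 2, 'attn_resolutions': [16],
    #         'ch_mult': [1, 1, 2, 2, 4], 'in_channels': 3, 'out_ch': 3,
    #         'resolution': 256, 'z_channels': 256, 'double_z': False,
    #         'dropout': 0.0, 'n_embed': 1024, 'embed_dim': 256,
    #     }
    #     model = VQModel(opt)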

    def init_training_settings(self):
        self.loss = BCELossWithQuant()
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def save_network(self, save_path):
        """Save the networks to disk.

        Args:
            save_path (str): Path to the checkpoint file to write.
        """
        save_dict = {}
        save_dict['encoder'] = self.encoder.state_dict()
        save_dict['decoder'] = self.decoder.state_dict()
        save_dict['quantize'] = self.quantize.state_dict()
        save_dict['quant_conv'] = self.quant_conv.state_dict()
        save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
        # Only the adversarially trained subclasses define a discriminator.
        if hasattr(self, 'disc'):
            save_dict['discriminator'] = self.disc.state_dict()
        torch.save(save_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
        self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
        self.post_quant_conv.load_state_dict(
            checkpoint['post_quant_conv'], strict=True)

    def optimize_parameters(self, data, current_iter):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss = self.training_step(data)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff
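
    # Sketch of the autoencoding round trip (assumes `x` is a float batch of
    # shape (B, in_channels, resolution, resolution) on the model device):
    #
    #     quant, emb_loss, _ = model.encode(x)  # quantized latents + VQ loss
    #     recon = model.decode(quant)           # reconstruction, (B, out_ch, H, W)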

    def feed_data(self, data):
        # Convert the integer segmentation map to a one-hot, channels-first
        # float tensor.
        x = data['segm']
        x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])

        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float().to(self.device)

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update the learning rate according to the configured schedule.

        Args:
            epoch (int): Current epoch number.

        Returns:
            float: The updated learning rate.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # Linear decay of 95% until the turning point
                # (1 / 0.95 ~= 1.0526), then exponential decay by gamma.
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr
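
    # Typical epoch loop (hypothetical loop variables, shown for context):
    #
    #     for epoch in range(opt['num_epochs']):
    #         lr = model.update_learning_rate(epoch)
    #         for step, batch in enumerate(train_loader):
    #             model.optimize_parameters(batch, step)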


class VQSegmentationModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        # Fixed random projection used to map one-hot segmentation maps to
        # RGB for visualization (see `to_rgb`).
        self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
                                    1).to(self.device)

        self.init_training_settings()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'],
            betas=(0.5, 0.9))

    def training_step(self, data):
        x = self.feed_data(data)
        xrec, qloss = self.forward_step(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split='train')
        self.log_dict.update(log_dict_ae)
        return aeloss

    def to_rgb(self, x):
        # Project the one-hot map to 3 channels and rescale to [-1, 1].
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        loss_bce = 0
        loss_quant = 0
        num = 0

        for _, data in enumerate(data_loader):
            # Per-image filenames assume a validation batch size of 1.
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, qloss = self.forward_step(x)
            _, log_dict_ae = self.loss(qloss, x, xrec, split='val')

            loss_total += log_dict_ae['val/total_loss']
            loss_bce += log_dict_ae['val/bce_loss']
            loss_quant += log_dict_ae['val/quant_loss']

            num += x.size(0)

            if x.shape[1] > 3:
                # One-hot segmentation maps: take the argmax class, re-encode
                # as one-hot, then colorize both input and reconstruction.
                assert xrec.shape[1] > 3
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item(), (loss_bce / num).item(), (
            loss_quant / num).item()


class VQImageModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net='vgg').to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = 'color,translation'

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image']

        return x.float().to(self.device)
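
    # Expected batch layout, assumed from the keys used above:
    #     data['image'] -> (B, 3, H, W) float tensor, typically in [-1, 1]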

    def init_training_settings(self):
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'])

        self.disc_optimizer = torch.optim.Adam(
            self.disc.parameters(), lr=self.opt['lr'])

    def training_step(self, data, step):
        x = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x)

        # L1 reconstruction loss plus LPIPS perceptual loss.
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # Augment the reconstruction before the discriminator sees it.
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # Generator adversarial loss, balanced against the reconstruction
        # loss by an adaptive weight and disabled before disc_start_step.
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict['loss'] = loss.item()
        self.log_dict['l1'] = recon_loss.mean().item()
        self.log_dict['perceptual'] = p_loss.mean().item()
        self.log_dict['nll_loss'] = nll_loss.item()
        self.log_dict['g_loss'] = g_loss.item()
        self.log_dict['d_weight'] = float(d_weight)
        self.log_dict['codebook_loss'] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach())
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict['d_loss'] = d_loss.item()
        else:
            d_loss = None

        return loss, d_loss
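
    # Overall objective as computed above, written out (w_p is
    # perceptual_weight, D is the discriminator):
    #
    #     loss = mean(|x - xrec| + w_p * LPIPS(x, xrec))
    #            + d_weight * (-mean(D(xrec))) + codebook_loss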

    def optimize_parameters(self, data, step):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss, d_loss = self.training_step(data, step)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if step > self.disc_start_step:
            self.disc_optimizer.zero_grad()
            d_loss.backward()
            self.disc_optimizer.step()

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, _ = self.forward_step(x)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # One-hot (segmentation-style) inputs; this branch is never
                # taken for 3-channel RGB images, and `to_rgb` is only
                # defined on VQSegmentationModel.
                assert xrec.shape[1] > 3
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()


class VQImageSegmTextureModel(VQImageModel):

    def __init__(self, opt):
        # Mirrors VQModel.__init__ but swaps in VectorQuantizerTexture so
        # that quantization can be conditioned on the texture mask.
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizerTexture(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt['z_channels'], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt['z_channels'],
                                               1).to(self.device)

        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net='vgg').to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = 'color,translation'

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image'].float().to(self.device)
        mask = data['texture_mask'].float().to(self.device)

        return x, mask
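
    # Expected batch layout, assumed from the keys used above:
    #     data['image']        -> (B, 3, H, W) float tensor
    #     data['texture_mask'] -> per-pixel texture labels, consumed by
    #                             VectorQuantizerTexture during quantization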

    def training_step(self, data, step):
        x, mask = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x, mask)

        # L1 reconstruction loss plus LPIPS perceptual loss.
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # Augment the reconstruction before the discriminator sees it.
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # Generator adversarial loss with adaptive weighting, disabled
        # before disc_start_step.
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict['loss'] = loss.item()
        self.log_dict['l1'] = recon_loss.mean().item()
        self.log_dict['perceptual'] = p_loss.mean().item()
        self.log_dict['nll_loss'] = nll_loss.item()
        self.log_dict['g_loss'] = g_loss.item()
        self.log_dict['d_weight'] = float(d_weight)
        self.log_dict['codebook_loss'] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach())
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict['d_loss'] = d_loss.item()
        else:
            d_loss = None

        return loss, d_loss

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x, mask = self.feed_data(data)
            xrec, _ = self.forward_step(x, mask)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # One-hot (segmentation-style) inputs; never taken for
                # 3-channel RGB images.
                assert xrec.shape[1] > 3
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()

    def encode(self, x, mask):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h, mask)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input, mask):
        quant, diff, _ = self.encode(input, mask)
        dec = self.decode(quant)
        return dec, diff
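
# Minimal end-to-end usage sketch (hypothetical config and loaders; the real
# entry point is the project's training script, not this module):
#
#     model = VQImageModel(opt)
#     for epoch in range(opt['num_epochs']):
#         model.update_learning_rate(epoch)
#         for step, batch in enumerate(train_loader):
#             model.optimize_parameters(batch, step)
#     val_loss = model.inference(val_loader, save_dir='./results')
#     model.save_network('./vqgan_latest.pth')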