| import numpy as np |
| import einops |
| import torch |
| import torch as th |
| import torch.nn as nn |
| from torch.nn import functional as thf |
| import pytorch_lightning as pl |
| import torchvision |
| from copy import deepcopy |
| from ldm.modules.diffusionmodules.util import ( |
| conv_nd, |
| linear, |
| zero_module, |
| timestep_embedding, |
| ) |
| from contextlib import contextmanager, nullcontext |
| from einops import rearrange, repeat |
| from torchvision.utils import make_grid |
| from ldm.modules.attention import SpatialTransformer |
| from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock |
| from ldm.models.diffusion.ddpm import LatentDiffusion |
| from ldm.util import log_txt_as_img, exists, instantiate_from_config, default |
| from ldm.models.diffusion.ddim import DDIMSampler |
| from ldm.modules.ema import LitEma |
| from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution |
| from ldm.modules.diffusionmodules.model import Encoder |
| import lpips |
| import kornia |
| from kornia import color |
|
|
def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Assigned via ``module.train = disabled_train`` after ``module.eval()``
    so the frozen module can never be flipped back into training mode.
    The ``mode`` flag is deliberately ignored; ``self`` is returned to
    keep the fluent ``nn.Module.train`` contract.
    """
    return self
|
|
class View(nn.Module):
    """Fixed-shape reshape layer, usable inside ``nn.Sequential``."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed as ``self.shape`` (no data copy)."""
        return x.view(self.shape)
|
|
|
|
class SecretEncoder3(nn.Module):
    """Map a secret bit-string to a 3-channel spatial residual.

    The secret is linearly projected to a small ``base_res`` map,
    upsampled to ``resolution``, then passed through a zero-initialised
    conv so the module's output starts at exactly zero.
    """

    def __init__(self, secret_len, base_res=16, resolution=64) -> None:
        super().__init__()
        # Both resolutions are assumed to be powers of two.
        upscale = 2 ** (int(np.log2(resolution)) - int(np.log2(base_res)))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * 3),
            nn.SiLU(),
            View(-1, 3, base_res, base_res),
            nn.Upsample(scale_factor=(upscale, upscale)),
            zero_module(conv_nd(2, 3, 3, 3, padding=1)),
        )

    def copy_encoder_weight(self, ae_model):
        # No weights are shared with the autoencoder.
        return None

    def encode(self, x):
        """Project secret bits to a (B, 3, resolution, resolution) map."""
        return self.secret_scaler(x)

    def forward(self, x, c):
        """Return ``(residual, None)`` for secret ``c``; image ``x`` is ignored."""
        return self.encode(c), None
|
|
|
|
class SecretEncoder4(nn.Module):
    """same as SecretEncoder3 but with ch as input"""

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
        super().__init__()
        # Both resolutions are assumed to be powers of two.
        upscale = 2 ** (int(np.log2(resolution)) - int(np.log2(base_res)))
        self.secret_len = secret_len
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(upscale, upscale)),
            # Zero-initialised so the residual starts at zero.
            zero_module(conv_nd(2, ch, ch, 3, padding=1)),
        )

    def copy_encoder_weight(self, ae_model):
        # No weights are shared with the autoencoder.
        return None

    def encode(self, x):
        """Project secret bits to a (B, ch, resolution, resolution) map."""
        return self.secret_scaler(x)

    def forward(self, x, c):
        """Return ``(residual, None)`` for secret ``c``; image ``x`` is ignored."""
        return self.encode(c), None
| |
class SecretEncoder6(nn.Module):
    """join img emb with secret emb

    ``emode`` selects the fusion strategy:
      * ``'c3'`` - concat a full-channel secret map onto the image embedding
      * ``'c2'`` - concat a 2-channel secret map onto the image channel mean
      * ``'m3'`` - multiply the secret map into the image embedding
    """

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
        super().__init__()
        assert emode in ['c3', 'c2', 'm3']
        # (secret channels, channels entering the join encoder) per mode.
        secret_ch, join_ch = {
            'c3': (ch, 2 * ch),
            'c2': (2, ch),
            'm3': (ch, ch),
        }[emode]
        # Both resolutions are assumed to be powers of two.
        upscale = 2 ** (int(np.log2(resolution)) - int(np.log2(base_res)))
        self.secret_len = secret_len
        self.emode = emode
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * secret_ch),
            nn.SiLU(),
            View(-1, secret_ch, base_res, base_res),
            nn.Upsample(scale_factor=(upscale, upscale)),
        )
        self.join_encoder = nn.Sequential(
            conv_nd(2, join_ch, join_ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, join_ch, ch, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, ch, ch, 3, padding=1),
            nn.SiLU(),
        )
        # Zero-initialised so the fused residual starts at zero.
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # No weights are shared with the autoencoder.
        return None

    def encode(self, x):
        """Project secret bits to a (B, secret_ch, resolution, resolution) map."""
        return self.secret_scaler(x)

    def forward(self, x, c):
        """Fuse image embedding ``x`` with secret ``c``; returns ``(residual, None)``."""
        secret_map = self.encode(c)
        if self.emode == 'c3':
            joined = torch.cat([x, secret_map], dim=1)
        elif self.emode == 'c2':
            joined = torch.cat([x.mean(dim=1, keepdim=True), secret_map], dim=1)
        else:  # 'm3' — guaranteed by the constructor assert
            joined = x * secret_map
        return self.out_layer(self.join_encoder(joined)), None
| |
class SecretEncoder5(nn.Module):
    """same as SecretEncoder3 but with ch as input

    With ``joint=True`` the upsampled secret map is additionally fused
    with the (resized) input image through a small conv stack before the
    zero-initialised output conv.
    """

    def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
        super().__init__()
        # Both resolutions are assumed to be powers of two.
        upscale = 2 ** (int(np.log2(resolution)) - int(np.log2(base_res)))
        self.secret_len = secret_len
        self.joint = joint
        self.resolution = resolution
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, base_res * base_res * ch),
            nn.SiLU(),
            View(-1, ch, base_res, base_res),
            nn.Upsample(scale_factor=(upscale, upscale)),
        )
        if joint:
            self.join_encoder = nn.Sequential(
                conv_nd(2, 2 * ch, 2 * ch, 3, padding=1),
                nn.SiLU(),
                conv_nd(2, 2 * ch, ch, 3, padding=1),
                nn.SiLU(),
            )
        # Zero-initialised so the residual starts at zero.
        self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))

    def copy_encoder_weight(self, ae_model):
        # No weights are shared with the autoencoder.
        return None

    def encode(self, x):
        """Project secret bits to a (B, ch, resolution, resolution) map."""
        return self.secret_scaler(x)

    def forward(self, x, c):
        """Return ``(residual, None)``; image ``x`` is only used when ``joint``."""
        secret_map = self.encode(c)
        if self.joint:
            resized = thf.interpolate(
                x,
                size=(self.resolution, self.resolution),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            secret_map = self.join_encoder(torch.cat([resized, secret_map], dim=1))
        return self.out_layer(secret_map), None
|
|
|
|
class SecretEncoder2(nn.Module):
    """Secret encoder that concatenates an upsampled secret map to the image
    and runs the joint tensor through a zero-initialised ``Encoder``.

    Args:
        secret_len: number of bits in the secret message.
        embed_dim: latent embedding size (stored; not used in this class).
        ddconfig: encoder config; ``resolution`` and ``out_ch`` are read as
            attributes and the whole config is forwarded to ``Encoder``.
        ckpt_path: optional checkpoint to restore from (non-strict).
        ignore_keys: state-dict key prefixes to drop when restoring.
        image_key: batch key of the input image.
        colorize_nlabels: if set, registers a random colorization buffer.
        monitor: metric name for checkpoint callbacks.
        ema_decay: if set, keeps an EMA copy of the weights.
        learn_logvar: stored flag (unused in this module).
    """

    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        # Zero-init the encoder's final conv so the module initially outputs zeros.
        self.encoder.conv_out = zero_module(self.encoder.conv_out)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))

        if monitor is not None:
            self.monitor = monitor

        # Secret -> 32x32 map -> upsample to the input resolution
        # (resolution assumed to be a power of two >= 32).
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32*32*ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))),
        )

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Non-strict restore from ``path``, dropping keys with given prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        # Intentionally a no-op: this variant trains its encoder from scratch.
        # (Fix: removed unreachable code that followed `return None` — it also
        # referenced self.quant_conv, an attribute this class never defines.)
        return None

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Refresh the EMA shadow weights after every optimiser step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Run the (zero-initialised) encoder; returns the raw feature map."""
        h = self.encoder(x)
        posterior = h
        return posterior

    def forward(self, x, c):
        """Concat the secret map for ``c`` onto image ``x`` and encode.

        Returns ``(z, None)``; the second element mirrors the posterior
        slot returned by the sibling encoder classes.
        """
        c = self.secret_scaler(c)
        x = torch.cat([x, c], dim=1)
        z = self.encode(x)
        return z, None
|
|
|
|
class SecretEncoder7(nn.Module):
    """Encode (image luma + secret map) into an ``embed_dim``-channel latent.

    A 2-channel secret map is nearest-upsampled to the image size, stacked
    with a single luma channel derived from the RGB input, and run through
    a VAE-style ``Encoder`` followed by a 1x1 "quant" conv.
    """

    def __init__(self, secret_len, ddconfig, ckpt_path=None,
                 ignore_keys=[],embed_dim=3,
                 ema_decay=None) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))  # NOTE(review): computed but unused here
        self.secret_len = secret_len
        self.encoder = Encoder(**ddconfig)

        # 1x1 conv mapping encoder output channels to the latent embed_dim.
        self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)

        # Project the secret bits to a fixed 2-channel 32x32 map; it is
        # resized to the actual input resolution in forward().
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32*32*2),
            nn.SiLU(),
            View(-1, 2, 32, 32),
        )

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Non-strict restore from a checkpoint at ``path``.

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. Do not worry as we are not using the decoder and the secret encoder is a novel module.")

    def copy_encoder_weight(self, ae_model):
        """Initialise encoder and quant conv from a pretrained autoencoder.

        NOTE(review): assumes the AE's encoder/quant_conv shapes match this
        module's (ddconfig / embed_dim agree) — confirm against the AE config.
        """
        self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
        self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Refresh the EMA shadow weights after every optimiser step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encoder + 1x1 quant conv; returns the latent feature map."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def forward(self, x, c):
        """Embed secret ``c`` conditioned on image ``x``; returns ``(z, None)``."""
        c = self.secret_scaler(c)
        # Nearest-neighbour resize of the secret map to the image size.
        c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
        # Collapse RGB to a single luma channel (Rec. 709-style weights).
        x = 0.2125 * x[:,0,...] + 0.7154 *x[:,1,...] + 0.0721 * x[:,2,...]
        x = torch.cat([x.unsqueeze(1), c], dim=1)
        z = self.encode(x)
        return z, None
|
|
class SecretEncoder(nn.Module):
    """Variational secret encoder.

    Adds a zero-initialised secret residual to the image, encodes the sum
    with a VAE ``Encoder``, samples from the diagonal Gaussian posterior,
    and maps the sample through a zero-initialised output conv.
    """

    def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False) -> None:
        super().__init__()
        log_resolution = int(np.log2(ddconfig.resolution))
        self.secret_len = secret_len
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        # double_z: encoder emits mean and logvar channels for the posterior.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))

        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        # Secret -> 32x32 map -> upsample to image resolution -> zero conv,
        # so the secret residual starts at exactly zero.
        self.secret_scaler = nn.Sequential(
            nn.Linear(secret_len, 32*32*ddconfig.out_ch),
            nn.SiLU(),
            View(-1, ddconfig.out_ch, 32, 32),
            nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))),
            zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
        )

        # Zero-initialised so the final output starts at zero as well.
        self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Non-strict restore from ``path``, dropping keys with given prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        misses, ignores = self.load_state_dict(sd, strict=False)
        print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")

    def copy_encoder_weight(self, ae_model):
        """Copy the pretrained AE's encoder and quant-conv weights into this module."""
        self.encoder.load_state_dict(ae_model.encoder.state_dict())
        self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Refresh the EMA shadow weights after every optimiser step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` to a ``DiagonalGaussianDistribution`` posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def forward(self, x, c):
        """Add the secret residual for ``c`` to ``x``, encode, sample, zero-conv.

        Returns ``(z, posterior)``; the posterior can be used by the caller
        for a KL regulariser.
        """
        c = self.secret_scaler(c)
        x = x + c
        posterior = self.encode(x)
        z = posterior.sample()
        z = self.out_layer(z)
        return z, posterior
|
|
|
|
class ControlAE(pl.LightningModule):
    """Watermark-embedding trainer built around a frozen autoencoder.

    A frozen first-stage AE (``self.ae``) provides the latent space; a
    trainable ``control`` network embeds a secret into the latent, and a
    trainable ``decoder`` recovers the secret bits from the (optionally
    noised / center-cropped) reconstructed image.
    """

    def __init__(self,
                 first_stage_key,
                 first_stage_config,
                 control_key,
                 control_config,
                 decoder_config,
                 loss_config,
                 noise_config='__none__',
                 use_ema=False,
                 secret_warmup=False,
                 scale_factor=1.,
                 ckpt_path="__none__",
                 ):
        super().__init__()
        self.scale_factor = scale_factor
        self.control_key = control_key
        self.first_stage_key = first_stage_key
        self.ae = instantiate_from_config(first_stage_config)
        self.control = instantiate_from_config(control_config)
        self.decoder = instantiate_from_config(decoder_config)
        # Differentiable center crop applied to the reconstruction when no
        # noise module is active.
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")
        if noise_config != '__none__':
            print('Using noise')
            self.noise = instantiate_from_config(noise_config)

        # Optionally seed the control net from the frozen AE's weights
        # (a no-op for most SecretEncoder variants).
        self.control.copy_encoder_weight(self.ae)

        # Freeze the first-stage AE entirely; train() is disabled so eval
        # mode can never be reverted.
        self.ae.eval()
        self.ae.train = disabled_train
        for p in self.ae.parameters():
            p.requires_grad = False

        self.loss_layer = instantiate_from_config(loss_config)
        # NOTE(review): compute_loss() reads self.lpips_loss and
        # self.yuv_scales, which are not set in this __init__ — presumably
        # attached elsewhere; confirm before relying on compute_loss().

        # Fixed-input warmup state: the first batch seen is cached and
        # reused every step until bit accuracy passes 0.9 (see shared_step).
        self.fixed_x = None
        self.fixed_img = None
        self.fixed_input_recon = None
        self.fixed_control = None
        self.register_buffer("fixed_input", torch.tensor(True))

        # Secret warmup: train first on a short base pattern repeated to
        # full length (requires secret_len to be a power of two).
        self.secret_warmup = secret_warmup
        self.secret_baselen = 2
        self.secret_len = control_config.params.secret_len
        if self.secret_warmup:
            assert self.secret_len == 2**(int(np.log2(self.secret_len)))

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.control_ema = LitEma(self.control)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        if ckpt_path != '__none__':
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

    def get_warmup_secret(self, old_secret):
        """Return a warmup secret (random base pattern repeated to full
        length) when secret warmup is on; otherwise pass ``old_secret`` through."""
        if self.secret_warmup:
            bsz = old_secret.shape[0]
            nrepeats = self.secret_len // self.secret_baselen
            new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
            return new_secret.to(old_secret.device)
        else:
            return old_secret

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Non-strict restore from a checkpoint at ``path``."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Swap control/decoder weights for their EMA copies; restore on exit."""
        if self.use_ema:
            self.control_ema.store(self.control.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.control_ema.copy_to(self.control)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.control_ema.restore(self.control.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Lightning hook: refresh the EMA copies after each optimiser step.
        if self.use_ema:
            self.control_ema(self.control)
            self.decoder_ema(self.decoder)

    def compute_loss(self, pred, target):
        """LPIPS + weighted YUV-MSE loss between images in [-1, 1].

        NOTE(review): relies on self.lpips_loss and self.yuv_scales being
        set outside __init__ — confirm where they are attached.
        """
        lpips_loss = self.lpips_loss(pred, target).mean(dim=[1,2,3])
        # Images are in [-1, 1]; rgb_to_yuv expects [0, 1].
        pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
        target_yuv = color.rgb_to_yuv((target + 1) / 2)
        yuv_loss = torch.mean((pred_yuv - target_yuv)**2, dim=[2,3])
        yuv_loss = 1.5*torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
        return lpips_loss + yuv_loss

    def forward(self, x, image, c):
        """Add the secret residual to latent ``x``; returns ``(x + eps, posterior)``.

        SecretEncoder6 consumes the latent itself; all other control nets
        consume the decoded image.
        """
        if self.control.__class__.__name__ == 'SecretEncoder6':
            eps, posterior = self.control(x, c)
        else:
            eps, posterior = self.control(image, c)
        return x + eps, posterior

    @torch.no_grad()
    def get_input(self, batch, return_first_stage=False, bs=None):
        """Fetch ``[latent, secret]`` (plus ``[image, recon]`` if requested).

        During fixed-input warmup the first cached batch is returned
        instead of fresh data.
        """
        image = batch[self.first_stage_key]
        control = batch[self.control_key]
        control = self.get_warmup_secret(control)
        if bs is not None:
            image = image[:bs]
            control = control[:bs]
        else:
            bs = image.shape[0]
        # (B, H, W, C) -> (B, C, H, W)
        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
        x = self.encode_first_stage(image).detach()
        image_rec = self.decode_first_stage(x).detach()

        if self.fixed_input:
            if self.fixed_x is None:  # first call: cache this batch
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_x = x.detach().clone()[:bs]
                self.fixed_img = image.detach().clone()[:bs]
                self.fixed_input_recon = image_rec.detach().clone()[:bs]
                self.fixed_control = control.detach().clone()[:bs]
            x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon

        out = [x, control]
        if return_first_stage:
            out.extend([image, image_rec])
        return out

    def decode_first_stage(self, z):
        """Decode latent ``z`` to image space with the frozen AE."""
        z = 1./self.scale_factor * z
        image_rec = self.ae.decode(z)
        return image_rec

    def encode_first_stage(self, image):
        """Encode ``image`` with the frozen AE (sampling if it returns a
        Gaussian posterior), scaled by ``scale_factor``."""
        encoder_posterior = self.ae.encode(image)
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def shared_step(self, batch):
        """One train/val step: embed secret, reconstruct, decode bits, loss.

        Also drives the curriculum: activates the noise module, the loss
        ramp, and ends fixed-input warmup as bit accuracy improves.
        """
        x, c, img, _ = self.get_input(batch, return_first_stage=True)
        x, posterior = self(x, img, c)
        image_rec = self.decode_first_stage(x)
        # Cap the working resolution at 256 to bound decoder/loss cost.
        if img.shape[-1] > 256:
            img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
            image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
        if hasattr(self, 'noise') and self.noise.is_activated():
            image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
        else:
            image_rec_noised = self.crop(image_rec)  # mild augmentation fallback
        image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
        pred = self.decoder(image_rec_noised)

        loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
        bit_acc = loss_dict["bit_acc"]

        bit_acc_ = bit_acc.item()

        # NOTE(review): unlike the checks below, this one does not guard
        # self.noise with hasattr — it will raise if noise_config was
        # '__none__'; confirm a noise module is always configured.
        if (bit_acc_ > 0.98) and (not self.fixed_input) and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)

        # High accuracy on real data: switch on the noise curriculum.
        if (bit_acc_ > 0.95) and (not self.fixed_input):
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)

        # Warmup complete: leave the fixed-input phase (bool tensor flip).
        if (bit_acc_ > 0.9) and self.fixed_input:
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Lightning training step: compute and log losses, return loss."""
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log validation metrics both with and without EMA weights."""
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        """Return a dict of input/output/recon (and noised) images for logging."""
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
        else:
            x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
        x, _ = self(x, img, c)
        image_out = self.decode_first_stage(x)
        if hasattr(self, 'noise') and self.noise.is_activated():
            img_noise = self.noise(image_out, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = img
        log['output'] = image_out
        log['recon'] = img_recon
        return log

    def configure_optimizers(self):
        """AdamW over the control net and secret decoder only (AE is frozen)."""
        lr = self.learning_rate
        params = list(self.control.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
| |
|
|
|
|
|
|
|
|
|
|