# ================================================================== # L A T E N T D I F F U S I O N M O D E L # ================================================================== # Author : Ashish Kumar Uchadiya # Created : November 3, 2024 # Description: This script implements a Latent Diffusion Model using # a cosine or linear noise scheduling approach for high-resolution # image generation. The model leverages generative techniques to # learn a latent representation and progressively reduce noise to # generate clear, realistic images. # ================================================================== # I M P O R T S # ================================================================== import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" """Lpips""" # from __future__ import absolute_import from collections import namedtuple import torch import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable import numpy as np import torch.nn import torchvision # Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def spatial_average(in_tens, keepdim=True): return in_tens.mean([2, 3], keepdim=keepdim) class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = torchvision.models.vgg16( weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1 ).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) # Freeze vgg model if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): # Return output of vgg features h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out # Learned perceptual metric class LPIPS(nn.Module): def __init__(self, net='vgg', version='0.1', use_dropout=True): super(LPIPS, self).__init__() self.version = version # Imagenet normalization self.scaling_layer = ScalingLayer() ######################## # Instantiate vgg model self.chns = [64, 128, 256, 512, 512] self.L = len(self.chns) self.net = vgg16(pretrained=True, requires_grad=False) # Add 1x1 convolutional Layers self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] self.lins = nn.ModuleList(self.lins) ######################## # Load the weights of trained LPIPS model import inspect import os # 
        # /home/taruntejaneurips23/.cache/torch/hub/checkpoints/vgg16-397923af.pth
        # print(os.path.abspath(os.path.join(inspect.getfile(self.__init__),
        #                                    '..', 'weights/v%s/%s.pth' % (version, net))))
        # model_path = os.path.abspath(
        #     os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
        # print('Loading model from: %s' % model_path)
        # self.load_state_dict(torch.load(model_path, map_location=device), strict=False)

        ########################
        # Freeze all parameters
        self.eval()
        for param in self.parameters():
            param.requires_grad = False
        ########################

    def forward(self, in0, in1, normalize=False):
        # Scale the inputs to the -1 to +1 range if needed
        if normalize:
            # turn on this flag if the input is in [0, 1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        ########################
        # Normalize the inputs with ImageNet statistics
        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)

        ########################
        # Get VGG outputs for image0 and image1
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        ########################
        # Compute the squared difference of unit-normalized features for each layer
        for kk in range(self.L):
            feats0[kk] = torch.nn.functional.normalize(outs0[kk], dim=1)
            feats1[kk] = torch.nn.functional.normalize(outs1[kk], dim=1)
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        ########################
        # 1x1 convolution followed by a spatial average of the squared differences
        res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]

        # Aggregate the results of each layer
        val = 0
        for l in range(self.L):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # ImageNet normalization for inputs in [0, 1]:
        # mean = [0.485, 0.456, 0.406]
        # std = [0.229, 0.224, 0.225]
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout()] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        out = self.model(x)
        return out
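
# --- Example (sketch): using LPIPS above as a perceptual distance -------------
# Kept commented out so importing this file stays side-effect free; the tensors
# are illustrative only. Inputs are expected in [-1, 1] (or pass normalize=True
# for [0, 1] inputs), exactly as trainVAE below uses it.
# lpips_metric = LPIPS().to(device)
# img0 = torch.rand(4, 3, 64, 64, device=device) * 2 - 1
# img1 = torch.rand(4, 3, 64, 64, device=device) * 2 - 1
# with torch.no_grad():
#     dist = lpips_metric(img0, img1)  # shape (4, 1, 1, 1): one distance per image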
"""Blocks"""

import torch
import numpy as np


class LinearNoiseScheduler:
    r"""
    Class for the linear noise scheduler that is used in DDPM.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        # Mimicking how the CompVis repo creates its schedule
        self.betas = (
            torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
        )
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Forward method for diffusion
        :param original: Image on which noise is to be applied
        :param noise: Random noise tensor (from a normal distribution)
        :param t: timestep of the forward process of shape -> (B,)
        :return: noisy version of the input at timestep t
        """
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)

        # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
        for _ in range(len(original_shape) - 1):
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        # Apply and return the forward process equation
        return (sqrt_alpha_cum_prod.to(original.device) * original
                + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        Use the noise prediction by the model to get x_{t-1} from x_t and the noise predicted
        :param xt: current timestep sample
        :param noise_pred: model noise prediction
        :param t: current timestep we are at
        :return: x_{t-1} sample and the prediction of x0
        """
        x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
              torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
        x0 = torch.clamp(x0, -1., 1.)

        mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
        mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])

        if t == 0:
            return mean, x0
        else:
            variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
            variance = variance * self.betas.to(xt.device)[t]
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)
            # OR
            # variance = self.betas[t]
            # sigma = variance ** 0.5
            # z = torch.randn(xt.shape).to(xt.device)
            return mean + sigma * z, x0


import torch
import math


class CosineNoiseScheduler:
    r"""
    Class for the cosine noise scheduler, often used in DDPM-based models.
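
    Schedule (sketch), cf. the cosine schedule of Nichol & Dhariwal,
    "Improved Denoising Diffusion Probabilistic Models":
        f(t) = cos(((t / T + s) / (1 + s)) * (pi / 2)) ** 2
    Note: this implementation uses f(t) directly as alpha_t and takes a
    cumulative product, whereas the paper defines alpha_bar(t) = f(t) / f(0)
    and derives the betas from it, so the two schedules are close but not
    identical.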
""" def __init__(self, num_timesteps, s=0.008): self.num_timesteps = num_timesteps self.s = s # Cosine schedule based on paper def cosine_schedule(t): return math.cos((t / self.num_timesteps + s) / (1 + s) * math.pi / 2) ** 2 # Compute alphas self.alphas = torch.tensor([cosine_schedule(t) for t in range(num_timesteps)]) self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) def add_noise(self, original, noise, t): original_shape = original.shape batch_size = original_shape[0] sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) for _ in range(len(original_shape) - 1): sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) for _ in range(len(original_shape) - 1): sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) return (sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise) def sample_prev_timestep(self, xt, noise_pred, t): x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) x0 = torch.clamp(x0, -1., 1.) mean = xt - ((1 - self.alphas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) if t == 0: return mean, x0 else: variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) variance = variance * (1 - self.alphas.to(xt.device)[t]) sigma = variance ** 0.5 z = torch.randn(xt.shape).to(xt.device) return mean + sigma * z, x0 import torch import torch.nn as nn def get_time_embedding(time_steps, temb_dim): r""" Convert time steps tensor into an embedding using the sinusoidal time embedding formula :param time_steps: 1D tensor of length batch size :param temb_dim: Dimension of the embedding :return: BxD embedding representation of B time steps """ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" # factor = 10000^(2i/d_model) factor = 10000 ** ((torch.arange( start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) ) # pos / factor # timesteps B -> B, 1 -> B, temb_dim t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) return t_emb class DownBlock(nn.Module): r""" Down conv block with attention. Sequence of following block 1. Resnet block with time embedding 2. Attention block 3. 
Downsample """ def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None): super().__init__() self.num_layers = num_layers self.down_sample = down_sample self.attn = attn self.context_dim = context_dim self.cross_attn = cross_attn self.t_emb_dim = t_emb_dim self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList([ nn.Sequential( nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels) ) for _ in range(num_layers) ]) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity() def forward(self, x, t_emb=None, context=None): out = x for i in range(self.num_layers): # Resnet block of Unet resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) if self.attn: # Attention block of Unet batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn if self.cross_attn: assert context is not None, "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # Downsample out = self.down_sample_conv(out) return out class MidBlock(nn.Module): r""" Mid conv block with attention. Sequence of following blocks 1. Resnet block with time embedding 2. Attention block 3. 
Resnet block with time embedding """ def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None): super().__init__() self.num_layers = num_layers self.t_emb_dim = t_emb_dim self.context_dim = context_dim self.cross_attn = cross_attn self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for i in range(num_layers + 1) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList([ nn.Sequential( nn.SiLU(), nn.Linear(t_emb_dim, out_channels) ) for _ in range(num_layers + 1) ]) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers + 1) ] ) self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers + 1) ] ) def forward(self, x, t_emb=None, context=None): out = x # First resnet block resnet_input = out out = self.resnet_conv_first[0](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] out = self.resnet_conv_second[0](out) out = out + self.residual_input_conv[0](resnet_input) for i in range(self.num_layers): # Attention Block batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn if self.cross_attn: assert context is not None, "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # Resnet Block resnet_input = out out = self.resnet_conv_first[i + 1](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] out = self.resnet_conv_second[i + 1](out) out = out + self.residual_input_conv[i + 1](resnet_input) return out class UpBlock(nn.Module): r""" Up conv block with attention. Sequence of following blocks 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. 
Attention Block """ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, attn, norm_channels): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.attn = attn self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList([ nn.Sequential( nn.SiLU(), nn.Linear(t_emb_dim, out_channels) ) for _ in range(num_layers) ]) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [ nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers) ] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1) \ if self.up_sample else nn.Identity() def forward(self, x, out_down=None, t_emb=None): # Upsample x = self.up_sample_conv(x) # Concat with Downblock output if out_down is not None: x = torch.cat([x, out_down], dim=1) out = x for i in range(self.num_layers): # Resnet Block resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) # Self Attention if self.attn: batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn return out class UpBlockUnet(nn.Module): r""" Up conv block with attention. Sequence of following blocks 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. 
Attention Block """ def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.cross_attn = cross_attn self.context_dim = context_dim self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), nn.SiLU(), nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for i in range(num_layers) ] ) if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList([ nn.Sequential( nn.SiLU(), nn.Linear(t_emb_dim, out_channels) ) for _ in range(num_layers) ]) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) for _ in range(num_layers) ] ) self.attention_norms = nn.ModuleList( [ nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers) ] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers) ] ) if self.cross_attn: assert context_dim is not None, "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)] ) self.context_proj = nn.ModuleList( [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) for i in range(num_layers) ] ) self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1) \ if self.up_sample else nn.Identity() def forward(self, x, out_down=None, t_emb=None, context=None): x = self.up_sample_conv(x) if out_down is not None: x = torch.cat([x, out_down], dim=1) out = x for i in range(self.num_layers): # Resnet resnet_input = out out = self.resnet_conv_first[i](out) if self.t_emb_dim is not None: out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] out = self.resnet_conv_second[i](out) out = out + self.residual_input_conv[i](resnet_input) # Self Attention batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn # Cross Attention if self.cross_attn: assert context is not None, "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape(batch_size, channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) assert len(context.shape) == 3, \ "Context shape does not match B,_,CONTEXT_DIM" assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\ "Context shape does not match B,_,CONTEXT_DIM" context_proj = self.context_proj[i](context) out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) out = out + out_attn return out """Vqvae""" import torch import torch.nn as nn class VQVAE(nn.Module): def 
__init__(self, im_channels, model_config): super().__init__() self.down_channels = model_config.down_channels self.mid_channels = model_config.mid_channels self.down_sample = model_config.down_sample self.num_down_layers = model_config.num_down_layers self.num_mid_layers = model_config.num_mid_layers self.num_up_layers = model_config.num_up_layers # To disable attention in Downblock of Encoder and Upblock of Decoder self.attns = model_config.attn_down # Latent Dimension self.z_channels = model_config.z_channels self.codebook_size = model_config.codebook_size self.norm_channels = model_config.norm_channels self.num_heads = model_config.num_heads # Assertion to validate the channel information assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-1] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 # Wherever we use downsampling in encoder correspondingly use # upsampling in decoder self.up_sample = list(reversed(self.down_sample)) ##################### Encoder ###################### self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)) # Downblock + Midblock self.encoder_layers = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], t_emb_dim=None, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels)) self.encoder_mids = nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels)) self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1) # Pre Quantization Convolution self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) # Codebook self.embedding = nn.Embedding(self.codebook_size, self.z_channels) #################################################### ##################### Decoder ###################### # Post Quantization Convolution self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) # Midblock + Upblock self.decoder_mids = nn.ModuleList([]) for i in reversed(range(1, len(self.mid_channels))): self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels)) self.decoder_layers = nn.ModuleList([]) for i in reversed(range(1, len(self.down_channels))): self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], t_emb_dim=None, up_sample=self.down_sample[i - 1], num_heads=self.num_heads, num_layers=self.num_up_layers, attn=self.attns[i-1], norm_channels=self.norm_channels)) self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) def quantize(self, x): B, C, H, W = x.shape # B, C, H, W -> B, H, W, C x = x.permute(0, 2, 3, 1) # B, H, W, C -> B, H*W, C x = x.reshape(x.size(0), -1, x.size(-1)) # Find nearest 
        # embedding/codebook vector
        # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
        dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
        # (B, H*W)
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Replace encoder output with the nearest codebook vector
        # quant_out -> B*H*W, C
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

        # x -> B*H*W, C
        x = x.reshape((-1, x.size(-1)))
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
        }
        # Straight-through estimation
        quant_out = x + (quant_out - x).detach()

        # quant_out -> B, C, H, W
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        out = self.encoder_conv_in(x)
        for idx, down in enumerate(self.encoder_layers):
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        out = z
        out = self.post_quant_conv(out)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for idx, up in enumerate(self.decoder_layers):
            out = up(out)
        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses


"""Vae"""

import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # To disable attention in Downblock of Encoder and Upblock of Decoder
        self.attns = model_config['attn_down']

        # Latent Dimension
        self.z_channels = model_config['z_channels']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Assertions to validate the channel information
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Wherever we use downsampling in the encoder, correspondingly use
        # upsampling in the decoder
        self.up_sample = list(reversed(self.down_sample))

        ##################### Encoder ######################
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # Downblock + Midblock
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))

        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
num_layers=self.num_mid_layers, norm_channels=self.norm_channels)) self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1) # Latent Dimension is 2*Latent because we are predicting mean & variance self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1) #################################################### ##################### Decoder ###################### self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)) # Midblock + Upblock self.decoder_mids = nn.ModuleList([]) for i in reversed(range(1, len(self.mid_channels))): self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels)) self.decoder_layers = nn.ModuleList([]) for i in reversed(range(1, len(self.down_channels))): self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1], t_emb_dim=None, up_sample=self.down_sample[i - 1], num_heads=self.num_heads, num_layers=self.num_up_layers, attn=self.attns[i - 1], norm_channels=self.norm_channels)) self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1) def encode(self, x): out = self.encoder_conv_in(x) for idx, down in enumerate(self.encoder_layers): out = down(out) for mid in self.encoder_mids: out = mid(out) out = self.encoder_norm_out(out) out = nn.SiLU()(out) out = self.encoder_conv_out(out) out = self.pre_quant_conv(out) mean, logvar = torch.chunk(out, 2, dim=1) std = torch.exp(0.5 * logvar) sample = mean + std * torch.randn(mean.shape).to(device=x.device) return sample, out def decode(self, z): out = z out = self.post_quant_conv(out) out = self.decoder_conv_in(out) for mid in self.decoder_mids: out = mid(out) for idx, up in enumerate(self.decoder_layers): out = up(out) out = self.decoder_norm_out(out) out = nn.SiLU()(out) out = self.decoder_conv_out(out) return out def forward(self, x): z, encoder_output = self.encode(x) out = self.decode(z) return out, encoder_output """Discriminator""" import torch import torch.nn as nn class Discriminator(nn.Module): r""" PatchGAN Discriminator. Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to 1 scalar value , we instead predict grid of values. 
Where each grid is prediction of how likely the discriminator thinks that the image patch corresponding to the grid cell is real """ def __init__(self, im_channels=3, conv_channels=[64, 128, 256], kernels=[4,4,4,4], strides=[2,2,2,1], paddings=[1,1,1,1]): super().__init__() self.im_channels = im_channels activation = nn.LeakyReLU(0.2) layers_dim = [self.im_channels] + conv_channels + [1] self.layers = nn.ModuleList([ nn.Sequential( nn.Conv2d(layers_dim[i], layers_dim[i + 1], kernel_size=kernels[i], stride=strides[i], padding=paddings[i], bias=False if i !=0 else True), nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(), activation if i != len(layers_dim) - 2 else nn.Identity() ) for i in range(len(layers_dim) - 1) ]) def forward(self, x): out = x for layer in self.layers: out = layer(out) return out # if __name__ == '__main__': # x = torch.randn((2,3, 256, 256)) # prob = Discriminator(im_channels=3)(x) # print(prob.shape) # import os # image_paths = [os.path.join("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images", f) # for f in os.listdir("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images")] # image_paths import glob import os import torchvision from PIL import Image from tqdm import tqdm, trange # from utils.diffusion_utils import load_latents from torch.utils.data.dataset import Dataset import pickle import glob import os import torch def load_latents(latent_path): r""" Simple utility to save latents to speed up ldm training :param latent_path: :return: """ latent_maps = {} for fname in glob.glob(os.path.join(latent_path, '*.pkl')): s = pickle.load(open(fname, 'rb')) for k, v in s.items(): latent_maps[k] = v[0] return latent_maps def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob): if text_drop_prob > 0: text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0, 1) < text_drop_prob assert empty_text_embed is not None, ("Text Conditioning required as well as" " text dropping but empty text representation not created") text_embed[text_drop_mask, :, :] = empty_text_embed[0] return text_embed def drop_image_condition(image_condition, im, im_drop_prob): if im_drop_prob > 0: im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 1) > im_drop_prob return image_condition * im_drop_mask else: return image_condition def drop_class_condition(class_condition, class_drop_prob, im): if class_drop_prob > 0: class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0, 1) > class_drop_prob return class_condition * class_drop_mask else: return class_condition class MnistDataset(Dataset): r""" Nothing special here. Just a simple dataset class for mnist images. Created a dataset class rather using torchvision to allow replacement with any other image dataset """ def __init__(self, split, im_path, im_size, im_channels, use_latents=False, latent_path=None, condition_config=None): r""" Init method for initializing the dataset properties :param split: train/test to locate the image files :param im_path: root folder of images :param im_ext: image extension. assumes all images would be this type. 
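        :param use_latents: whether to train on precomputed VQVAE latents
            instead of raw images (latents must already exist at latent_path)
        :param latent_path: directory containing the pickled latent files
        :param condition_config: optional dict with a 'condition_types' list
            (e.g. ['class']); None disables conditioning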
""" self.split = split self.im_size = im_size self.im_channels = im_channels # Should we use latents or not self.latent_maps = None self.use_latents = False # Conditioning for the dataset self.condition_types = [] if condition_config is None else condition_config['condition_types'] self.images, self.labels = self.load_images(im_path) # Whether to load images and call vae or to load latents if use_latents and latent_path is not None: latent_maps = load_latents(latent_path) if len(latent_maps) == len(self.images): self.use_latents = True self.latent_maps = latent_maps print('Found {} latents'.format(len(self.latent_maps))) else: print('Latents not found') def load_images(self, im_path): r""" Gets all images from the path specified and stacks them all up :param im_path: :return: """ assert os.path.exists(im_path), "images path {} does not exist".format(im_path) ims = [] labels = [] for d_name in tqdm(os.listdir(im_path)): fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png'))) fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg'))) fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg'))) for fname in fnames: ims.append(fname) if 'class' in self.condition_types: labels.append(int(d_name)) print('Found {} images for split {}'.format(len(ims), self.split)) return ims, labels def __len__(self): return len(self.images) def __getitem__(self, index): ######## Set Conditioning Info ######## cond_inputs = {} if 'class' in self.condition_types: cond_inputs['class'] = self.labels[index] ####################################### if self.use_latents: latent = self.latent_maps[self.images[index]] if len(self.condition_types) == 0: return latent else: return latent, cond_inputs else: im = Image.open(self.images[index]) im_tensor = torchvision.transforms.ToTensor()(im) # Convert input to -1 to 1 range. 
im_tensor = (2 * im_tensor) - 1 if len(self.condition_types) == 0: return im_tensor else: return im_tensor, cond_inputs class AnimeFaceDataset(Dataset): def __init__(self, split, im_path, im_size, im_channels, use_latents=False, latent_path=None, condition_config=None): self.split = split self.im_size = im_size self.im_channels = im_channels # Should we use latents or not self.latent_maps = None self.use_latents = False # Conditioning for the dataset self.condition_types = [] if condition_config is None else condition_config['condition_types'] self.images = self.load_images(im_path) # Whether to load images and call vae or to load latents if use_latents and latent_path is not None: latent_maps = load_latents(latent_path) if len(latent_maps) == len(self.images): self.use_latents = True self.latent_maps = latent_maps print('Found {} latents'.format(len(self.latent_maps))) else: print('Latents not found') def load_images(self, im_path): r""" Gets all images from the path specified and stacks them all up :param im_path: :return: """ assert os.path.exists(im_path), "images path {} does not exist".format(im_path) # ims = [] # labels = [] ims = [os.path.join(im_path, f) for f in os.listdir(im_path)] return ims def __len__(self): return len(self.images) def __getitem__(self, index): ######## Set Conditioning Info ######## # cond_inputs = {} # if 'class' in self.condition_types: # cond_inputs['class'] = self.labels[index] ####################################### if self.use_latents: latent = self.latent_maps[self.images[index]] if len(self.condition_types) == 0: return latent # else: # return latent, cond_inputs else: im = Image.open(self.images[index]) im_tensor = torchvision.transforms.Compose([ torchvision.transforms.Resize(self.im_size), torchvision.transforms.CenterCrop(self.im_size), torchvision.transforms.ToTensor(), ])(im) im.close() # im_tensor = torchvision.transforms.ToTensor()(im) # Convert input to -1 to 1 range. 
im_tensor = (2 * im_tensor) - 1 if len(self.condition_types) == 0: return im_tensor # else: # return im_tensor, cond_inputs import glob import os import random import torch import torchvision import numpy as np from PIL import Image from tqdm import tqdm from torch.utils.data.dataset import Dataset class CelebDataset(Dataset): def __init__(self, split, im_path, im_size, im_channels, use_latents=False, latent_path=None, condition_config=None): self.split = split self.im_size = im_size self.im_channels = im_channels # Should we use latents or not self.latent_maps = None self.use_latents = False # Conditioning for the dataset self.condition_types = [] if condition_config is None else condition_config['condition_types'] self.images = self.load_images(im_path) # Whether to load images and call vae or to load latents if use_latents and latent_path is not None: latent_maps = load_latents(latent_path) if len(latent_maps) == len(self.images): self.use_latents = True self.latent_maps = latent_maps print('Found {} latents'.format(len(self.latent_maps))) else: print('Latents not found') def load_images(self, im_path): r""" Gets all images from the path specified and stacks them all up :param im_path: :return: """ assert os.path.exists(im_path), "images path {} does not exist".format(im_path) # ims = [] # labels = [] ims = [os.path.join(im_path, f) for f in os.listdir(im_path)] return ims def __len__(self): return len(self.images) def __getitem__(self, index): ######## Set Conditioning Info ######## # cond_inputs = {} # if 'class' in self.condition_types: # cond_inputs['class'] = self.labels[index] ####################################### if self.use_latents: latent = self.latent_maps[self.images[index]] if len(self.condition_types) == 0: return latent # else: # return latent, cond_inputs else: im = Image.open(self.images[index]) im_tensor = torchvision.transforms.Compose([ # torchvision.transforms.Resize(self.im_size), torchvision.transforms.CenterCrop(self.im_size), torchvision.transforms.ToTensor(), ])(im) im.close() # im_tensor = torchvision.transforms.ToTensor()(im) # Convert input to -1 to 1 range. 
im_tensor = (2 * im_tensor) - 1 if len(self.condition_types) == 0: return im_tensor # else: # return im_tensor, cond_inputs import pandas as pd class CelebHairDataset(Dataset): def __init__(self, split, im_path, im_size, im_channels, use_latents=False, latent_path=None, condition_config=None): self.df = pd.read_csv("/home/taruntejaneurips23/Ashish/DDPM/hair_df_100.csv") self.split = split self.im_size = im_size self.im_channels = im_channels # Should we use latents or not self.latent_maps = None self.use_latents = False # Conditioning for the dataset self.condition_types = [] if condition_config is None else condition_config['condition_types'] self.images = self.load_images(im_path, self.df) # Whether to load images and call vae or to load latents if use_latents and latent_path is not None: latent_maps = load_latents(latent_path) if len(latent_maps) == len(self.images): self.use_latents = True self.latent_maps = latent_maps print('Found {} latents'.format(len(self.latent_maps))) else: print('Latents not found') def load_images(self, im_path, df): r""" Gets all images from the path specified and stacks them all up :param im_path: :return: """ assert os.path.exists(im_path), "images path {} does not exist".format(im_path) # ims = [] # labels = [] # ims = [os.path.join(im_path, f) for f in os.listdir(im_path)] ims = [os.path.join(im_path, i) for i in df.image_id.values] return ims def __len__(self): return len(self.images) def __getitem__(self, index): ######## Set Conditioning Info ######## # cond_inputs = {} # if 'class' in self.condition_types: # cond_inputs['class'] = self.labels[index] ####################################### if self.use_latents: latent = self.latent_maps[self.images[index]] if len(self.condition_types) == 0: return latent # else: # return latent, cond_inputs else: im = Image.open(self.images[index]) im_tensor = torchvision.transforms.Compose([ # torchvision.transforms.Resize(self.im_size), torchvision.transforms.CenterCrop(self.im_size), torchvision.transforms.ToTensor(), ])(im) im.close() # im_tensor = torchvision.transforms.ToTensor()(im) # Convert input to -1 to 1 range. im_tensor = (2 * im_tensor) - 1 if len(self.condition_types) == 0: return im_tensor # else: # return im_tensor, cond_inputs #"""Train VQVAE"""............................................................................................................................................... # Commented out IPython magic to ensure Python compatibility. 
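
# A sketch of the YAML layout the code below expects (key names inferred from
# the attribute accesses in this script; the values shown are illustrative,
# not the ones actually used here):
#
# dataset_params:
#   name: 'celebAhair'          # one of: mnist / celebA / animeface / celebAhair
#   im_path: '/path/to/images'
#   im_size: 128
#   im_channels: 3
# diffusion_params:
#   num_timesteps: 1000
#   beta_start: 0.0015
#   beta_end: 0.0195
# autoencoder_params:           # VQVAE config: down_channels, mid_channels,
#   ...                         #   z_channels, codebook_size, attn_down, ...
# ldm_params:                   # diffusion UNet config: down_channels,
#   ...                         #   time_emb_dim, conv_out_channels, ...
# train_params:                 # task_name, seed, batch sizes, lrs, loss weights,
#   ...                         #   checkpoint names, num_samples, num_grid_rows, ...
# training:
#   _continue_: true            # polled each epoch to allow stopping training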
import torch
import torch.nn as nn
import yaml
from dotdict import DotDict

config_path = "/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml"
with open(config_path, 'r') as file:
    Config = yaml.safe_load(file)
Config = DotDict.from_dict(Config)

dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params

import torch
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, TensorDataset, DataLoader

# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from torchvision.utils import make_grid


def trainVAE(Config):
    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    # Set the desired seed value
    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(seed)

    #############################
    # Create the model and dataset
    model = VQVAE(im_channels=dataset_config.im_channels,
                  model_config=autoencoder_config).to(device)
    # model.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_autoencoder_ckpt.pth", map_location=device))
    if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)):
        print('Loaded vae checkpoint')
        model.load_state_dict(torch.load(os.path.join(train_config.task_name,
                                                      train_config.vqvae_autoencoder_ckpt_name),
                                         map_location=device, weights_only=True))

    # Create the dataset
    im_dataset_cls = {
        'mnist': MnistDataset,
        'celebA': CelebDataset,
        'animeface': AnimeFaceDataset,
        'celebAhair': CelebHairDataset
    }.get(dataset_config.name)

    im_dataset = im_dataset_cls(split='train',
                                im_path=dataset_config.im_path,
                                im_size=dataset_config.im_size,
                                im_channels=dataset_config.im_channels)

    data_loader = DataLoader(im_dataset,
                             batch_size=train_config.autoencoder_batch_size,
                             shuffle=True,
                             num_workers=os.cpu_count(),
                             pin_memory=True,
                             drop_last=True,
                             persistent_workers=True,
                             pin_memory_device=device)

    # Create output directories
    if not os.path.exists(train_config.task_name):
        os.mkdir(train_config.task_name)

    num_epochs = train_config.autoencoder_epochs

    # L1/L2 loss for reconstruction
    recon_criterion = torch.nn.MSELoss()
    # Disc loss can even be BCEWithLogits
    disc_criterion = torch.nn.MSELoss()

    # No need to freeze lpips as lpips.py takes care of that
    lpips_model = LPIPS().eval().to(device)

    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
    # discriminator.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_discriminator_ckpt.pth", map_location=device))
    if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)):
        print('Loaded discriminator checkpoint')
        discriminator.load_state_dict(torch.load(os.path.join(train_config.task_name,
                                                              train_config.vqvae_discriminator_ckpt_name),
                                                 map_location=device, weights_only=True))

    optimizer_d = Adam(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = Adam(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    disc_step_start = train_config.disc_start
    step_count = 0

    # This is for accumulating gradients in case the images are huge
    # and one can't afford higher batch sizes
    acc_steps = train_config.autoencoder_acc_steps
    image_save_steps = train_config.autoencoder_img_save_steps
    img_save_count = 0
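
    # Objective assembled by the loop below (sketch): per generator step,
    #   g_loss = recon
    #            + codebook_weight * codebook_loss + commitment_beta * commitment_loss
    #            + disc_weight * adversarial_loss   (only after disc_step_start steps)
    #            + perceptual_weight * LPIPS_loss,
    # with every term divided by acc_steps so that accumulating gradients over
    # acc_steps mini-batches matches one large-batch step.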
    for epoch_idx in trange(num_epochs, desc='Training VQVAE'):
        recon_losses = []
        codebook_losses = []
        # commitment_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        losses = []

        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        # for im in tqdm(data_loader):
        for im in data_loader:
            step_count += 1
            im = im.float().to(device)

            # Fetch the autoencoder's output (reconstructions)
            model_output = model(im)
            output, z, quantize_losses = model_output

            # Image saving logic
            if step_count % image_save_steps == 0 or step_count == 1:
                sample_size = min(8, im.shape[0])
                save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
                save_output = ((save_output + 1) / 2)
                save_input = ((im[:sample_size] + 1) / 2).detach().cpu()

                grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
                img = torchvision.transforms.ToPILImage()(grid)
                if not os.path.exists(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples')):
                    os.mkdir(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples'))
                img.save(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples',
                                      'current_autoencoder_sample_{}.png'.format(img_save_count)))
                img_save_count += 1
                img.close()

            ######### Optimize Generator ##########
            # L2 loss
            recon_loss = recon_criterion(output, im)
            recon_losses.append(recon_loss.item())
            recon_loss = recon_loss / acc_steps
            g_loss = (recon_loss
                      + (train_config.codebook_weight * quantize_losses['codebook_loss'] / acc_steps)
                      + (train_config.commitment_beta * quantize_losses['commitment_loss'] / acc_steps))
            codebook_losses.append(train_config.codebook_weight * quantize_losses['codebook_loss'].item())

            # Adversarial loss only if disc_step_start steps have passed
            if step_count > disc_step_start:
                disc_fake_pred = discriminator(model_output[0])
                disc_fake_loss = disc_criterion(disc_fake_pred,
                                                torch.ones(disc_fake_pred.shape,
                                                           device=disc_fake_pred.device))
                gen_losses.append(train_config.disc_weight * disc_fake_loss.item())
                g_loss += train_config.disc_weight * disc_fake_loss / acc_steps

            lpips_loss = torch.mean(lpips_model(output, im)) / acc_steps
            perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())
            g_loss += train_config.perceptual_weight * lpips_loss / acc_steps
            losses.append(g_loss.item())
            g_loss.backward()
            #####################################

            ######### Optimize Discriminator #######
            if step_count > disc_step_start:
                fake = output
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im)
                disc_fake_loss = disc_criterion(disc_fake_pred,
                                                torch.zeros(disc_fake_pred.shape,
                                                            device=disc_fake_pred.device))
                disc_real_loss = disc_criterion(disc_real_pred,
                                                torch.ones(disc_real_pred.shape,
                                                           device=disc_real_pred.device))
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
                disc_losses.append(disc_loss.item())
                disc_loss = disc_loss / acc_steps
                disc_loss.backward()
                if step_count % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
            #####################################

            if step_count % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()

        # Flush any gradients still accumulated at the end of the epoch
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

        if len(disc_losses) > 0:
            print(
                'Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
                'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
format(epoch_idx + 1, num_epochs, np.mean(recon_losses), np.mean(perceptual_losses), np.mean(codebook_losses), np.mean(gen_losses), np.mean(disc_losses))) else: print('Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'. format(epoch_idx + 1, num_epochs, np.mean(recon_losses), np.mean(perceptual_losses), np.mean(codebook_losses))) torch.save(model.state_dict(), os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)) torch.save(discriminator.state_dict(), os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)) print('Done Training...') # trainVAE(Config) import torch import torch.nn as nn class Unet(nn.Module): r""" Unet model comprising Down blocks, Midblocks and Uplocks """ def __init__(self, im_channels, model_config): super().__init__() self.down_channels = model_config.down_channels self.mid_channels = model_config.mid_channels self.t_emb_dim = model_config.time_emb_dim self.down_sample = model_config.down_sample self.num_down_layers = model_config.num_down_layers self.num_mid_layers = model_config.num_mid_layers self.num_up_layers = model_config.num_up_layers self.attns = model_config.attn_down self.norm_channels = model_config.norm_channels self.num_heads = model_config.num_heads self.conv_out_channels = model_config.conv_out_channels assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-2] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 # Initial projection from sinusoidal time embedding self.t_proj = nn.Sequential( nn.Linear(self.t_emb_dim, self.t_emb_dim), nn.SiLU(), nn.Linear(self.t_emb_dim, self.t_emb_dim) ) self.up_sample = list(reversed(self.down_sample)) self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) self.downs = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels)) self.mids = nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels)) self.ups = nn.ModuleList([]) for i in reversed(range(len(self.down_channels) - 1)): self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels, self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_up_layers, norm_channels=self.norm_channels)) self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) def forward(self, x, t): # Shapes assuming downblocks are [C1, C2, C3, C4] # Shapes assuming midblocks are [C4, C4, C3] # Shapes assuming downsamples are [True, True, False] # B x C x H x W out = self.conv_in(x) # B x C1 x H x W # t_emb -> B x t_emb_dim t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) t_emb = self.t_proj(t_emb) down_outs = [] for idx, down in enumerate(self.downs): down_outs.append(out) out = down(out, t_emb) # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] # out B x C4 x H/4 x W/4 for mid in self.mids: out = mid(out, 
t_emb) # out B x C3 x H/4 x W/4 for up in self.ups: down_out = down_outs.pop() out = up(out, down_out, t_emb) # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] out = self.norm_out(out) out = nn.SiLU()(out) out = self.conv_out(out) # out B x C x H x W return out def trainLDM(Config): diffusion_config = Config.diffusion_params dataset_config = Config.dataset_params diffusion_model_config = Config.ldm_params autoencoder_model_config = Config.autoencoder_params train_config = Config.train_params # Create the noise scheduler scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, beta_start=diffusion_config.beta_start, beta_end=diffusion_config.beta_end) # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) im_dataset_cls = { 'mnist': MnistDataset, 'celebA': CelebDataset, 'animeface': AnimeFaceDataset, 'celebAhair': CelebHairDataset }.get(dataset_config.name) im_dataset = im_dataset_cls(split='train', im_path=dataset_config.im_path, im_size=dataset_config.im_size, im_channels=dataset_config.im_channels, use_latents=True, latent_path=os.path.join(train_config.task_name, train_config.vqvae_latent_dir_name) ) data_loader = DataLoader(im_dataset, batch_size=train_config.ldm_batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True, drop_last=False, persistent_workers=True, pin_memory_device=device) # Instantiate the model model = Unet(im_channels=autoencoder_model_config.z_channels, model_config=diffusion_model_config).to(device) if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)): print('Loaded ldm checkpoint') model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.ldm_ckpt_name), map_location=device, weights_only=True)) model.train() # Load VAE ONLY if latents are not to be used or are missing if not im_dataset.use_latents: print('Loading vqvae model as latents not present') vae = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_model_config).to(device) vae.eval() # Load vae if found if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)): print('Loaded vae checkpoint') vae.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name), map_location=device)) # Specify training parameters num_epochs = train_config.ldm_epochs optimizer = Adam(model.parameters(), lr=train_config.ldm_lr) criterion = torch.nn.MSELoss() # Run training if not im_dataset.use_latents: for param in vae.parameters(): param.requires_grad = False for epoch_idx in range(num_epochs): losses = [] for im in tqdm(data_loader): optimizer.zero_grad() im = im.float().to(device) if not im_dataset.use_latents: with torch.no_grad(): im, _ = vae.encode(im) # Sample random noise noise = torch.randn_like(im).to(device) # Sample timestep t = torch.randint(0, diffusion_config.num_timesteps, (im.shape[0],)).to(device) # Add noise to images according to timestep noisy_im = scheduler.add_noise(im, noise, t) noise_pred = model(noisy_im, t) loss = criterion(noise_pred, noise) losses.append(loss.item()) loss.backward() optimizer.step() print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}') torch.save(model.state_dict(), os.path.join(train_config.task_name, train_config.ldm_ckpt_name)) # Doing Inference infer(Config) # Checking to conntinue training train_continue = yaml.safe_load(open("/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml", 'r')) train_continue = 
# import subprocess
# subprocess.run(f'kill {os.getpid()}', shell=True, check=True)


def sample(model, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config, vae):
    r"""
    Sample stepwise by going backward one timestep at a time.
    We save the x0 predictions.
    """
    im_size = dataset_config.im_size // 2 ** sum(autoencoder_model_config.down_sample)
    xt = torch.randn((train_config.num_samples,
                      autoencoder_model_config.z_channels,
                      im_size,
                      im_size)).to(device)

    for i in tqdm(reversed(range(diffusion_config.num_timesteps)),
                  total=diffusion_config.num_timesteps):
        # Get prediction of noise
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))

        # Use scheduler to get x0 and xt-1
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        # Save x0
        # ims = torch.clamp(xt, -1., 1.).detach().cpu()
        if i == 0:
            # Decode ONLY the final image to save time
            ims = vae.decode(xt)
        else:
            ims = xt

        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config.num_grid_rows)
        img = torchvision.transforms.ToPILImage()(grid)

        if not os.path.exists(os.path.join(train_config.task_name, 'samples')):
            os.mkdir(os.path.join(train_config.task_name, 'samples'))
        img.save(os.path.join(train_config.task_name, 'samples', 'x0_{}.png'.format(i)))
        img.close()


def infer(Config):
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params

    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)
    # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)

    model = Unet(im_channels=autoencoder_model_config.z_channels,
                 model_config=diffusion_model_config).to(device)
    model.eval()
    if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)):
        print('Loaded unet checkpoint')
        model.load_state_dict(torch.load(os.path.join(train_config.task_name,
                                                      train_config.ldm_ckpt_name),
                                         map_location=device))

    # Create output directories
    if not os.path.exists(train_config.task_name):
        os.mkdir(train_config.task_name)

    vae = VQVAE(im_channels=dataset_config.im_channels,
                model_config=autoencoder_model_config).to(device)
    vae.eval()

    # Load vae if found
    if os.path.exists(os.path.join(train_config.task_name,
                                   train_config.vqvae_autoencoder_ckpt_name)):
        print('Loaded vae checkpoint')
        vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
                                                    train_config.vqvae_autoencoder_ckpt_name),
                                       map_location=device), strict=True)

    with torch.no_grad():
        sample(model, scheduler, train_config, diffusion_model_config,
               autoencoder_model_config, diffusion_config, dataset_config, vae)


import argparse

def get_args():
    parser = argparse.ArgumentParser(description="Choose between train VAE, train LDM, or infer mode.")
    parser.add_argument('--mode', choices=['train_vae', 'train_ldm', 'infer'], default='infer',
                        help="Mode to run: train_vae, train_ldm, or infer")
    return parser.parse_args()

args = get_args()
if args.mode == 'train_vae':
    trainVAE(Config)
elif args.mode == 'train_ldm':
    trainLDM(Config)
else:
    infer(Config)

# python _5.2_ldm_celeba_hair_cosine.py --mode train_vae
# python _5.2_ldm_celeba_hair_cosine.py --mode train_ldm
# python _5.2_ldm_celeba_hair_cosine.py --mode infer
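# --- Hedged example: what scheduler.sample_prev_timestep computes ---------
# A sketch of the standard DDPM reverse update assumed by sample() above:
#   x0_pred = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t)
#   mean    = (x_t - beta_t * eps / sqrt(1 - a_bar_t)) / sqrt(alpha_t)
#   x_{t-1} = mean + sigma_t * z,  z ~ N(0, I)   (no noise added at t = 0)
# This mirrors the DDPM paper; the actual scheduler class defined earlier
# in this script may differ in detail. Here t is a Python int.
def _sample_prev_timestep_reference(xt, noise_pred, t, betas):
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x0_pred = (xt - (1 - alpha_bars[t]).sqrt() * noise_pred) / alpha_bars[t].sqrt()
    x0_pred = torch.clamp(x0_pred, -1., 1.)
    mean = (xt - betas[t] * noise_pred / (1 - alpha_bars[t]).sqrt()) / alphas[t].sqrt()
    if t == 0:
        return mean, x0_pred
    variance = (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * betas[t]
    return mean + variance.sqrt() * torch.randn_like(xt), x0_pred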
# import matplotlib.pyplot as plt
# from PIL import Image
# # plt.style.use('dark_background')
# # %matplotlib inline
# plt.imshow(Image.open('/home/taruntejaneurips23/Ashish/DDPM/mnist_ldm/samples/x0_0.png'), cmap='gray')

# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# dataset_name = 'animeface_ldm'
# image_paths = [f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_0.png',
#                f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_1.png',
#                f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_5.png',
#                f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_100.png',
#                f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_200.png']
# fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
# for i, path in enumerate(image_paths):
#     img = mpimg.imread(path)
#     axes[i].imshow(img)
#     axes[i].axis('off')  # Hide axes
#     axes[i].set_title(f't = {path.split("/")[-1].split(".")[0].split("_")[-1]}')
# plt.tight_layout()
# plt.show()

# ---------------------------------------------------------
# ----------  T H E - E N D  ------------------------------
# ---------------------------------------------------------


def save_checkpoint(total_steps, epoch, model, discriminator,
                    optimizer_d, optimizer_g, loss, checkpoint_path):
    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_d_state_dict": optimizer_d.state_dict(),
        "optimizer_g_state_dict": optimizer_g.state_dict(),
        "loss": loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")


def load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
        # Optimizers are optional so weights can be restored without them
        if optimizer_d is not None:
            optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
        if optimizer_g is not None:
            optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
        total_steps = checkpoint["total_steps"]
        start_epoch = checkpoint["epoch"] + 1
        loss = checkpoint["loss"]
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        return total_steps, start_epoch, loss
    else:
        print("No checkpoint found. Starting from scratch.")
        return 0, 0, None
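# --- Hedged example: checkpoint round trip --------------------------------
# A minimal commented-out sketch of how the two helpers above pair up; the
# model/optimizer names are placeholders taken from the training loop below.
# save_checkpoint(total_steps, epoch_idx, model, discriminator,
#                 optimizer_d, optimizer_g, recon_losses, checkpoint_path)
# total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model,
#                                               discriminator, optimizer_d, optimizer_g)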
""" # --- Configurations ---------------------------------------------------- dataset_config = Config.dataset_params autoencoder_config = Config.autoencoder_params train_config = Config.train_params seed = train_config.seed torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if device == "cuda": torch.cuda.manual_seed_all(seed) # --- Model Initialization ---------------------------------------------- model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device) discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device) # --- Load Checkpoints -------------------------------------------------- checkpoint_path = os.path.join(train_config.task_name, "vqvae_checkpoint.pth") total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, None, None) # --- Loss Function Initialization -------------------------------------- recon_criterion = torch.nn.MSELoss() lpips_model = LPIPS().eval().to(device) disc_criterion = torch.nn.MSELoss() # --- Optimizer Initialization ------------------------------------------ optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) num_epochs = train_config.autoencoder_epochs acc_steps = train_config.autoencoder_acc_steps image_save_steps = train_config.autoencoder_img_save_steps img_save_count = 0 # Create necessary directories os.makedirs(os.path.join(train_config.task_name, "vqvae_autoencoder_samples"), exist_ok=True) # --- Training Loop ----------------------------------------------------- for epoch_idx in range(start_epoch, num_epochs): recon_losses, codebook_losses, perceptual_losses, disc_losses, gen_losses = [], [], [], [], [] for images in dataloader: total_steps += 1 images = images.to(device) # Forward pass model_output = model(images) output, z, quantize_losses = model_output # Save generated images periodically if total_steps % image_save_steps == 0 or total_steps == 1: sample_size = min(8, images.shape[0]) save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu() save_output = (save_output + 1) / 2 save_input = ((images[:sample_size] + 1) / 2).detach().cpu() grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) img = tv.transforms.ToPILImage()(grid) img.save( os.path.join( train_config.task_name, "vqvae_autoencoder_samples", f"current_autoencoder_sample_{img_save_count}.png", ) ) img_save_count += 1 img.close() # Reconstruction Loss recon_loss = recon_criterion(output, images) / acc_steps recon_losses.append(recon_loss.item()) # Generator Loss codebook_loss = train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps perceptual_loss = train_config.perceptual_weight * lpips_model(output, images).mean() / acc_steps g_loss = recon_loss + codebook_loss + perceptual_loss if total_steps > train_config.disc_start: disc_fake_pred = discriminator(output) gen_loss = train_config.disc_weight * disc_criterion( disc_fake_pred, torch.ones_like(disc_fake_pred) ) / acc_steps g_loss += gen_loss gen_losses.append(gen_loss.item()) g_loss.backward() optimizer_g.step() optimizer_g.zero_grad() # Discriminator Loss if total_steps > train_config.disc_start: disc_fake_pred = discriminator(output.detach()) disc_real_pred = discriminator(images) disc_fake_loss = disc_criterion( disc_fake_pred, torch.zeros_like(disc_fake_pred) ) / acc_steps disc_real_loss = disc_criterion( 
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
                disc_loss.backward()
                if total_steps % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()
                disc_losses.append(disc_loss.item())

        # Save checkpoint after each epoch
        save_checkpoint(total_steps, epoch_idx, model, discriminator,
                        optimizer_d, optimizer_g, recon_losses, checkpoint_path)

        # Print epoch summary (adversarial terms stay zero until disc_start is reached)
        print(
            f"Epoch {epoch_idx + 1}/{num_epochs} | Recon Loss: {np.mean(recon_losses):.4f} | "
            f"Perceptual Loss: {np.mean(perceptual_losses):.4f} | Codebook Loss: {np.mean(codebook_losses):.4f} | "
            f"G Loss: {np.mean(gen_losses) if gen_losses else 0.0:.4f} | "
            f"D Loss: {np.mean(disc_losses) if disc_losses else 0.0:.4f}"
        )
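# --- Hedged example: the generator objective in one place -----------------
# A compact restatement of the per-batch generator loss assembled in the
# loop above. Weights and names mirror train_config; this is illustrative,
# not a drop-in replacement for the training code.
def _generator_loss_reference(output, images, quantize_losses, disc_fake_pred,
                              recon_criterion, lpips_model, disc_criterion,
                              codebook_weight, perceptual_weight, disc_weight,
                              use_adversarial):
    loss = recon_criterion(output, images)
    loss = loss + codebook_weight * quantize_losses["codebook_loss"]
    loss = loss + perceptual_weight * lpips_model(output, images).mean()
    if use_adversarial:
        # Least-squares GAN generator term: push fake logits toward 1
        loss = loss + disc_weight * disc_criterion(disc_fake_pred,
                                                   torch.ones_like(disc_fake_pred))
    return loss

# trainVAE(Config, dataloader)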