| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from collections import namedtuple |
|
|
| import pandas as pd |
| import torchvision as tv |
| from torchvision.transforms import v2 |
| from tqdm import tqdm, trange |
| import matplotlib.pyplot as plt |
|
|
| import re |
| import glob |
| import sys |
| import yaml |
| import random |
| import datetime |
| import torch.hub |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision.utils import make_grid |
|
|
print("TIME:", datetime.datetime.now())

# Pick the compute device once; later code uses `device` as its default.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("DEVICE:", device)
|
|
|
|
| |
| |
| |
| from typing import Any |
| from argparse import Namespace |
| import typing |
|
|
|
|
class DotDict(Namespace):
    """A simple class that builds upon `argparse.Namespace`
    in order to make chained attributes possible.

    Attribute access rules:

    * Reading a missing *public* attribute on a regular node stores and
      returns an empty temporary child node, so ``d.a.b = 1`` works even when
      ``d.a`` did not exist yet.
    * Reading a missing attribute on a temporary node removes that node from
      its parent again and raises ``AttributeError``.
    * Assigning a public attribute on a temporary node turns it into a
      regular node.
    * Names starting with ``_`` (including dunders) are never auto-created,
      so ``hasattr``/``copy``/``pickle`` probes raise ``AttributeError``.
    """

    def __init__(self, temp=False, key=None, parent=None) -> None:
        # Bookkeeping for placeholder nodes created by __getattr__.
        self._temp = temp
        self._key = key
        self._parent = parent

    def _public_items(self) -> dict:
        """Return the user-visible attributes (names not starting with '_')."""
        return {k: v for k, v in vars(self).items() if not k.startswith("_")}

    def __setattr__(self, name: str, value: Any) -> None:
        if not name.startswith("_") and self.__dict__.get("_temp", False):
            # Storing data in a placeholder means the user meant to create
            # it: promote it to a regular node and drop the back-reference so
            # a later failed lookup cannot delete it from its parent.
            self.__dict__.update(_temp=False, _key=None, _parent=None)
        super().__setattr__(name, value)

    def __eq__(self, other):
        if not isinstance(other, DotDict):
            return NotImplemented
        # Compare user data only; the internal parent back-references could
        # otherwise recurse between two trees.
        return self._public_items() == other._public_items()

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails.
        if name.startswith("_"):
            # Dunder probes (__deepcopy__, __setstate__, ...) and the
            # bookkeeping fields themselves (absent on instances created
            # without __init__, e.g. by copy/pickle) must fail normally
            # instead of recursing or returning a bogus placeholder.
            raise AttributeError("No attribute '%s'" % name)
        state = self.__dict__
        if not state.get("_temp", False):
            child = DotDict(temp=True, key=name, parent=self)
            state[name] = child
            return child
        # Chained lookup through a placeholder: discard the placeholder.
        parent = state.get("_parent")
        key = state.get("_key")
        if parent is not None and vars(parent).get(key) is self:
            del vars(parent)[key]
        raise AttributeError("No attribute '%s'" % name)

    def __repr__(self) -> str:
        # Internal "_"-prefixed fields are hidden regardless of key count.
        items = self._public_items()
        return "DotDict(%s)" % ", ".join("%s=%r" % (k, v) for k, v in items.items())

    @classmethod
    def from_dict(cls, original: typing.Mapping[str, Any]) -> "DotDict":
        """Create a DotDict from a (possibly nested) dict `original`.

        Nested mappings become nested DotDicts; other values are stored as-is.
        Warning: this method should not be used on very deeply nested inputs,
        since it's recursively traversing the nested dictionary values.
        """
        dd = cls()
        for key, value in original.items():
            if isinstance(value, typing.Mapping):
                value = cls.from_dict(value)
            setattr(dd, key, value)
        return dd
| |
| |
| |
| |
| |
class vgg16(nn.Module):
    """Frozen ImageNet-pretrained torchvision VGG16 features split into five slices.

    ``forward`` runs the slices in sequence and returns every slice's output
    as a ``VggOutputs(h1, h2, h3, h4, h5)`` namedtuple. All parameters have
    ``requires_grad=False``.
    """

    # Built once at class level instead of on every forward call, so all
    # outputs share one type and no new class is created per batch.
    VggOutputs = namedtuple("VggOutputs", ["h1", "h2", "h3", "h4", "h5"])

    # Half-open index ranges of ``vgg16().features`` forming slice1..slice5.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    def __init__(self):
        super().__init__()
        vgg_pretrained_features = tv.models.vgg16(
            weights=tv.models.VGG16_Weights.IMAGENET1K_V1
        ).features
        self.N_slices = len(self._SLICE_BOUNDS)
        for slice_idx, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # Keep the original layer index as the submodule name so
                # state_dict keys (e.g. "slice2.5.weight") are unchanged.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, f"slice{slice_idx}", block)

        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, X):
        h1 = self.slice1(X)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        h5 = self.slice5(h4)
        return self.VggOutputs(h1, h2, h3, h4, h5)
|
|
|
|
| def _spatial_average(in_tens, keepdim=True): |
| return in_tens.mean([2, 3], keepdim=keepdim) |
|
|
|
|
| def _normalize_tensor(in_feat, eps= 1e-8): |
| norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True)) |
| return in_feat / norm_factor |
|
|
|
|
class ScalingLayer(nn.Module):
    """Per-channel affine normalisation: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are fixed 3-element constants stored as buffers
    shaped ``(1, 3, 1, 1)`` so they broadcast over a ``(N, 3, H, W)`` input
    and follow the module across ``.to(device)``.
    """

    def __init__(self):
        super().__init__()
        shift = torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        return inp.sub(self.shift).div(self.scale)
|
|
|
|
class NetLinLayer(nn.Module):
    """Bias-free 1x1 convolution from ``chn_in`` to ``chn_out`` channels.

    With ``use_dropout`` a ``nn.Dropout`` is placed first, so the conv lives
    at ``model.1``; otherwise it is ``model.0``. These indices appear in the
    state_dict keys that pretrained weights are loaded by.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        conv = nn.Conv2d(chn_in, chn_out, kernel_size=1, stride=1, padding=0, bias=False)
        if use_dropout:
            self.model = nn.Sequential(nn.Dropout(), conv)
        else:
            self.model = nn.Sequential(conv)

    def forward(self, x):
        return self.model(x)
|
|
|
|
class LPIPS(nn.Module):
    """LPIPS perceptual distance on a frozen VGG16 backbone with 1x1 linear heads.

    The heads' weights are downloaded from ``weights_url`` (built from
    ``version`` and ``net``). ``forward(in0, in1)`` returns one distance per
    sample, shape ``(N,)``. The module is put in eval mode and all parameters
    are frozen.

    Raises:
        RuntimeError: if the downloaded state dict does not provide the
            linear heads (e.g. ``use_dropout`` does not match the layout the
            weights were saved with).
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super().__init__()
        self.version = version
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        # NOTE(review): a VGG16 backbone is always built; `net` only selects
        # which weight file is downloaded.
        self.net = vgg16()
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])

        weights_url = f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth"
        state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
        # strict=False because the file presumably carries only the linear
        # heads (backbone weights come from torchvision, scaling constants are
        # buffers). That would also hide heads that failed to load, so verify
        # them explicitly. Each head is registered twice (lin{k} and lins.{k});
        # loading through either name fills the shared parameters.
        incompatible = self.load_state_dict(state_dict, strict=False)
        missing = set(incompatible.missing_keys)
        unloaded = [
            f"lin{k}.{name}"
            for k, lin in enumerate(self.lins)
            for name, _ in lin.named_parameters()
            if f"lin{k}.{name}" in missing and f"lins.{k}.{name}" in missing
        ]
        if unloaded:
            raise RuntimeError(
                f"LPIPS weights from {weights_url} did not provide {unloaded}; "
                f"unexpected keys: {sorted(incompatible.unexpected_keys)}. "
                "Check that `use_dropout` matches the saved layout."
            )

        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        """Return the per-sample LPIPS distance between ``in0`` and ``in1``.

        With ``normalize=True`` inputs are mapped by ``2 * x - 1`` (i.e. from
        [0, 1] to [-1, 1]) before scaling.
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # Call the backbone as a module (not .forward) so hooks still run.
        outs0 = self.net(self.scaling_layer(in0))
        outs1 = self.net(self.scaling_layer(in1))

        res = []
        for kk in range(self.L):
            diff = (_normalize_tensor(outs0[kk]) - _normalize_tensor(outs1[kk])) ** 2
            res.append(_spatial_average(self.lins[kk](diff), keepdim=True))
        val = sum(res)
        return val.reshape(-1)
|
|
|
|
| |
| |
| |
class Discriminator(nn.Module):
    r"""
    PatchGAN Discriminator.

    Rather than mapping an ``im_channels x H x W`` image all the way to one
    scalar, it predicts a grid of values; each grid cell scores how likely
    the discriminator thinks the corresponding image patch is real.

    Layers: ``len(conv_channels) + 1`` convolutions going
    ``im_channels -> *conv_channels -> 1``. The first conv has a bias and no
    BatchNorm, middle convs use BatchNorm (no bias), all but the last are
    followed by ``LeakyReLU(0.2)``.

    Args:
        im_channels: input image channels.
        conv_channels: output channels of the hidden convs (any sequence).
        kernels, strides, paddings: per-conv settings; each needs at least
            ``len(conv_channels) + 1`` entries.

    Raises:
        ValueError: if ``kernels``/``strides``/``paddings`` are too short.
    """

    def __init__(
        self,
        im_channels=3,
        conv_channels=(64, 128, 256),
        kernels=(4, 4, 4, 4),
        strides=(2, 2, 2, 1),
        paddings=(1, 1, 1, 1),
    ):
        super().__init__()
        self.im_channels = im_channels
        # LeakyReLU is stateless, so one instance is shared by all layers.
        activation = nn.LeakyReLU(0.2)
        # Unpacking accepts lists or tuples (and avoids mutable defaults).
        layers_dim = [self.im_channels, *conv_channels, 1]
        num_convs = len(layers_dim) - 1
        for label, values in (("kernels", kernels), ("strides", strides), ("paddings", paddings)):
            if len(values) < num_convs:
                raise ValueError(
                    f"{label} needs at least {num_convs} entries (one per conv), got {len(values)}"
                )
        last = num_convs - 1
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        layers_dim[i],
                        layers_dim[i + 1],
                        kernel_size=kernels[i],
                        stride=strides[i],
                        padding=paddings[i],
                        bias=(i == 0),
                    ),
                    nn.BatchNorm2d(layers_dim[i + 1]) if 0 < i < last else nn.Identity(),
                    activation if i != last else nn.Identity(),
                )
                for i in range(num_convs)
            ]
        )

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
|
|
|
|
|
|
| |
| |
| |
class DownBlock(nn.Module):
    r"""
    Down conv block with attention.

    For each of ``num_layers`` layers the block applies, in order:
      1. a ResNet block: GroupNorm -> SiLU -> 3x3 conv, plus the projected
         time embedding when ``t_emb_dim`` is set, then GroupNorm -> SiLU ->
         3x3 conv, with a 1x1-conv skip connection from the layer input;
      2. residual multi-head self-attention over the ``H * W`` positions
         (if ``attn``);
      3. residual multi-head cross-attention to ``context`` (if ``cross_attn``).
    Finally a 4x4 stride-2 conv downsamples (if ``down_sample``).

    Raises:
        ValueError: at construction if ``cross_attn`` is set without
            ``context_dim``; in ``forward`` if a required ``context`` or
            ``t_emb`` is missing, or ``context`` has the wrong batch size or
            last dimension.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        down_sample,
        num_heads,
        num_layers,
        attn,
        norm_channels,
        cross_attn=False,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim
        # NOTE: submodules are created in the same order as before. That order
        # fixes seeded initialisation and parameter order (optimizer state
        # dicts refer to parameters by position).
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers)
            ]
        )
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                    for _ in range(num_layers)
                ]
            )
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )

        if self.cross_attn:
            # Explicit raise (not assert) so the check survives `python -O`.
            if context_dim is None:
                raise ValueError("Context Dimension must be passed for cross attention")
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )

        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.down_sample_conv = (
            nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
        )

    def _resnet(self, i, out, t_emb):
        """Apply ResNet sub-block ``i`` (two norm/act/conv stages + 1x1 skip)."""
        resnet_input = out
        out = self.resnet_conv_first[i](out)
        if self.t_emb_dim is not None:
            if t_emb is None:
                raise ValueError("t_emb cannot be None when t_emb_dim is set")
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](resnet_input)

    @staticmethod
    def _attend(out, norm, attention, key_value=None):
        """Residual attention over the flattened ``H * W`` positions of ``out``.

        The normalised ``(B, H*W, C)`` tokens are the query, and also the
        key/value unless ``key_value`` is given (cross-attention context).
        """
        batch_size, channels, h, w = out.shape
        in_attn = norm(out.reshape(batch_size, channels, h * w)).transpose(1, 2)
        if key_value is None:
            key_value = in_attn
        out_attn, _ = attention(in_attn, key_value, key_value)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        return out + out_attn

    def _check_context(self, x, context):
        """Validate ``context`` for cross-attention (raises survive ``-O``)."""
        if context is None:
            raise ValueError("context cannot be None if cross attention layers are used")
        if context.shape[0] != x.shape[0] or context.shape[-1] != self.context_dim:
            raise ValueError(
                f"context must have batch size {x.shape[0]} and last dim "
                f"{self.context_dim}, got shape {tuple(context.shape)}"
            )

    def forward(self, x, t_emb=None, context=None):
        if self.cross_attn:
            self._check_context(x, context)
        out = x
        for i in range(self.num_layers):
            out = self._resnet(i, out, t_emb)
            if self.attn:
                out = self._attend(out, self.attention_norms[i], self.attentions[i])
            if self.cross_attn:
                context_proj = self.context_proj[i](context)
                out = self._attend(
                    out, self.cross_attention_norms[i], self.cross_attentions[i], context_proj
                )
        return self.down_sample_conv(out)
|
|
|
|
|
|
| |
| |
| |
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.

    Applies a ResNet block, then ``num_layers`` times: residual multi-head
    self-attention over the ``H * W`` positions, optional residual
    cross-attention to ``context`` (if ``cross_attn``), and another ResNet
    block. Each ResNet block is GroupNorm -> SiLU -> 3x3 conv (+ projected
    time embedding when ``t_emb_dim`` is set) -> GroupNorm -> SiLU -> 3x3
    conv, with a 1x1-conv skip connection. Spatial size is unchanged.

    Raises:
        ValueError: at construction if ``cross_attn`` is set without
            ``context_dim``; in ``forward`` if a required ``context`` or
            ``t_emb`` is missing, or ``context`` has the wrong batch size or
            last dimension.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        t_emb_dim,
        num_heads,
        num_layers,
        norm_channels,
        cross_attn=None,
        context_dim=None,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        # NOTE: submodules are created in the same order as before. That order
        # fixes seeded initialisation and parameter order (optimizer state
        # dicts refer to parameters by position).
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                    ),
                )
                for i in range(num_layers + 1)
            ]
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                [
                    nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                    for _ in range(num_layers + 1)
                ]
            )
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(norm_channels, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers + 1)
            ]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        if self.cross_attn:
            # Explicit raise (not assert) so the check survives `python -O`.
            if context_dim is None:
                raise ValueError("Context Dimension must be passed for cross attention")
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.cross_attentions = nn.ModuleList(
                [
                    nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                    for _ in range(num_layers)
                ]
            )
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)]
            )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers + 1)
            ]
        )

    def _resnet(self, i, out, t_emb):
        """Apply ResNet sub-block ``i`` (two norm/act/conv stages + 1x1 skip)."""
        resnet_input = out
        out = self.resnet_conv_first[i](out)
        if self.t_emb_dim is not None:
            if t_emb is None:
                raise ValueError("t_emb cannot be None when t_emb_dim is set")
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[i](out)
        return out + self.residual_input_conv[i](resnet_input)

    @staticmethod
    def _attend(out, norm, attention, key_value=None):
        """Residual attention over the flattened ``H * W`` positions of ``out``.

        The normalised ``(B, H*W, C)`` tokens are the query, and also the
        key/value unless ``key_value`` is given (cross-attention context).
        """
        batch_size, channels, h, w = out.shape
        in_attn = norm(out.reshape(batch_size, channels, h * w)).transpose(1, 2)
        if key_value is None:
            key_value = in_attn
        out_attn, _ = attention(in_attn, key_value, key_value)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        return out + out_attn

    def _check_context(self, x, context):
        """Validate ``context`` for cross-attention (raises survive ``-O``)."""
        if context is None:
            raise ValueError("context cannot be None if cross attention layers are used")
        if context.shape[0] != x.shape[0] or context.shape[-1] != self.context_dim:
            raise ValueError(
                f"context must have batch size {x.shape[0]} and last dim "
                f"{self.context_dim}, got shape {tuple(context.shape)}"
            )

    def forward(self, x, t_emb=None, context=None):
        if self.cross_attn:
            self._check_context(x, context)
        out = self._resnet(0, x, t_emb)
        for i in range(self.num_layers):
            out = self._attend(out, self.attention_norms[i], self.attentions[i])
            if self.cross_attn:
                context_proj = self.context_proj[i](context)
                out = self._attend(
                    out, self.cross_attention_norms[i], self.cross_attentions[i], context_proj
                )
            out = self._resnet(i + 1, out, t_emb)
        return out
|
|
|
|
| |
| |
| |
| class UpBlock(nn.Module): |
| r""" |
| Up conv block with attention. |
| Sequence of following blocks |
| 1. Upsample |
| 1. Concatenate Down block output |
| 2. Resnet block with time embedding |
| 3. Attention Block |
| """ |
|
|
| def __init__( |
| self, |
| in_channels, |
| out_channels, |
| t_emb_dim, |
| up_sample, |
| num_heads, |
| num_layers, |
| attn, |
| norm_channels, |
| ): |
| super().__init__() |
| self.num_layers = num_layers |
| self.up_sample = up_sample |
| self.t_emb_dim = t_emb_dim |
| self.attn = attn |
| self.resnet_conv_first = nn.ModuleList( |
| [ |
| nn.Sequential( |
| nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), |
| nn.SiLU(), |
| nn.Conv2d( |
| in_channels if i == 0 else out_channels, |
| out_channels, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| ), |
| ) |
| for i in range(num_layers) |
| ] |
| ) |
|
|
| if self.t_emb_dim is not None: |
| self.t_emb_layers = nn.ModuleList( |
| [ |
| nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) |
| for _ in range(num_layers) |
| ] |
| ) |
| self.resnet_conv_second = nn.ModuleList( |
| [ |
| nn.Sequential( |
| nn.GroupNorm(norm_channels, out_channels), |
| nn.SiLU(), |
| nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
| if self.attn: |
| self.attention_norms = nn.ModuleList( |
| [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] |
| ) |
|
|
| self.attentions = nn.ModuleList( |
| [ |
| nn.MultiheadAttention(out_channels, num_heads, batch_first=True) |
| for _ in range(num_layers) |
| ] |
| ) |
| self.residual_input_conv = nn.ModuleList( |
| [ |
| nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) |
| for i in range(num_layers) |
| ] |
| ) |
| self.up_sample_conv = ( |
| nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1) |
| if self.up_sample |
| else nn.Identity() |
| ) |
|
|
| def forward(self, x, out_down=None, t_emb=None): |
| |
|
|
| x = self.up_sample_conv(x) |
|
|
| |
|
|
| if out_down is not None: |
| x = torch.cat([x, out_down], dim=1) |
| out = x |
| for i in range(self.num_layers): |
| |
|
|
| resnet_input = out |
| out = self.resnet_conv_first[i](out) |
| if self.t_emb_dim is not None: |
| out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] |
| out = self.resnet_conv_second[i](out) |
| out = out + self.residual_input_conv[i](resnet_input) |
|
|
| |
|
|
| if self.attn: |
| batch_size, channels, h, w = out.shape |
| in_attn = out.reshape(batch_size, channels, h * w) |
| in_attn = self.attention_norms[i](in_attn) |
| in_attn = in_attn.transpose(1, 2) |
| out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) |
| out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) |
| out = out + out_attn |
| return out |
|
|
|
|
| |
| |
| |
class VQVAE(nn.Module):
    """Vector-quantised autoencoder.

    Encoder: 3x3 conv -> DownBlocks -> MidBlocks -> GroupNorm/SiLU/3x3 conv ->
    1x1 pre-quant conv -> nearest-codebook quantisation (``quantize``).
    Decoder: 1x1 post-quant conv -> 3x3 conv -> MidBlocks -> UpBlocks ->
    GroupNorm/SiLU/3x3 conv back to ``im_channels``.

    ``model_config`` is read by attribute and must provide ``down_channels``,
    ``mid_channels``, ``down_sample``, ``num_down_layers``,
    ``num_mid_layers``, ``num_up_layers``, ``attn_down``, ``z_channels``,
    ``codebook_size``, ``norm_channels`` and ``num_heads``.

    Raises:
        ValueError: if the channel lists or per-stage lists are inconsistent.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers

        # Per-DownBlock flag: whether that stage uses self-attention.
        self.attns = model_config.attn_down

        self.z_channels = model_config.z_channels
        self.codebook_size = model_config.codebook_size
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads

        # Explicit checks (not asserts) so they survive `python -O`.
        if self.mid_channels[0] != self.down_channels[-1]:
            raise ValueError("mid_channels[0] must equal down_channels[-1]")
        if self.mid_channels[-1] != self.down_channels[-1]:
            raise ValueError("mid_channels[-1] must equal down_channels[-1]")
        if len(self.down_sample) != len(self.down_channels) - 1:
            raise ValueError("down_sample needs len(down_channels) - 1 entries")
        if len(self.attns) != len(self.down_channels) - 1:
            raise ValueError("attn_down needs len(down_channels) - 1 entries")

        self.up_sample = list(reversed(self.down_sample))

        # NOTE: submodules are created in the same order as before. That order
        # fixes seeded initialisation and the parameter order saved optimizer
        # states refer to.
        self.encoder_conv_in = nn.Conv2d(
            im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1)
        )

        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    t_emb_dim=None,
                    down_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_down_layers,
                    attn=self.attns[i],
                    norm_channels=self.norm_channels,
                )
            )
        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
            )
        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(
            self.down_channels[-1], self.z_channels, kernel_size=3, padding=1
        )

        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook: `codebook_size` vectors of dimension `z_channels`.
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(
            self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1)
        )

        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i - 1],
                    t_emb_dim=None,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                )
            )
        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(
                UpBlock(
                    self.down_channels[i],
                    self.down_channels[i - 1],
                    t_emb_dim=None,
                    up_sample=self.down_sample[i - 1],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    attn=self.attns[i - 1],
                    norm_channels=self.norm_channels,
                )
            )
        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(
            self.down_channels[0], im_channels, kernel_size=3, padding=1
        )

    def quantize(self, x):
        """Snap every spatial feature vector of ``x`` to its nearest codebook entry.

        Args:
            x: ``(B, C, H, W)`` tensor with ``C == z_channels``.

        Returns:
            ``(quant_out, losses, indices)``: ``quant_out`` has ``x``'s shape
            and passes gradients straight through to ``x``; ``losses`` holds
            ``"codebook_loss"`` and ``"commitment_loss"`` (MSE terms with the
            opposite side detached); ``indices`` is ``(B, H, W)`` codebook ids.
        """
        B, C, H, W = x.shape
        codebook = self.embedding.weight

        # (B, C, H, W) -> (B*H*W, C): one row per spatial position. A plain
        # 2-D cdist against the (K, C) codebook avoids materialising a
        # per-sample copy of the codebook.
        flat = x.permute(0, 2, 3, 1).reshape(-1, C)
        dist = torch.cdist(flat, codebook)
        min_encoding_indices = torch.argmin(dist, dim=-1)
        quant_out = torch.index_select(codebook, 0, min_encoding_indices)

        commitment_loss = torch.mean((quant_out.detach() - flat) ** 2)
        codebook_loss = torch.mean((quant_out - flat.detach()) ** 2)
        quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commitment_loss}

        # Straight-through estimator: forward uses the codes, backward the input.
        quant_out = flat + (quant_out - flat).detach()

        quant_out = quant_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape(B, H, W)
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode images to quantised latents; returns ``(z, quant_losses)``."""
        out = self.encoder_conv_in(x)
        for down in self.encoder_layers:
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = nn.functional.silu(self.encoder_norm_out(out))
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode quantised latents back to ``im_channels`` images."""
        out = self.post_quant_conv(z)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for up in self.decoder_layers:
            out = up(out)
        out = nn.functional.silu(self.decoder_norm_out(out))
        return self.decoder_conv_out(out)

    def forward(self, x):
        """Reconstruct ``x``; returns ``(out, z, quant_losses)``.

        ``out`` has ``x``'s shape, ``z`` is the quantised latent with
        ``z_channels`` channels (e.g. the original notes give ``[B, 3, 256, 256]``
        inputs -> ``z`` of ``[B, 3, 64, 64]`` for their config), and
        ``quant_losses`` is ``{"codebook_loss": ..., "commitment_loss": ...}``.
        """
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
|
|
|
|
| |
| |
| |
import pprint

# Location of the YAML config; set VAANI_CONFIG to run from another checkout.
config_path = os.environ.get(
    "VAANI_CONFIG",
    "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-LDM-High-Pre.yaml",
)

with open(config_path, 'r', encoding='utf-8') as file:
    Config = yaml.safe_load(file)
pprint.pprint(Config, width=120)

# Fail fast on a malformed config: DotDict would otherwise return an empty
# placeholder node for a missing section instead of raising.
_REQUIRED_SECTIONS = ("dataset_params", "diffusion_params", "model_params", "train_params", "paths")
if not isinstance(Config, dict):
    raise TypeError(f"{config_path} must contain a YAML mapping, got {type(Config).__name__}")
_missing_sections = [name for name in _REQUIRED_SECTIONS if name not in Config]
if _missing_sections:
    raise KeyError(f"{config_path} is missing config sections: {_missing_sections}")

Config = DotDict.from_dict(Config)
dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params
paths = Config.paths

if "images_dir" not in vars(paths):
    raise KeyError(f"{config_path}: paths.images_dir is not set")
IMAGES_PATH = paths.images_dir
|
|
def walkDIR(folder_path, include=None):
    """Recursively collect file paths under ``folder_path``.

    Args:
        folder_path: root directory to walk.
        include: filename suffixes to keep, e.g. ``['.png', '.jpg']``, or a
            single suffix string; ``None`` keeps every file. Matching is
            case-sensitive.

    Returns:
        Sorted list of matching paths. Sorting makes the order independent of
        the filesystem, so seeded splits downstream are reproducible.
    """
    if isinstance(include, str):
        # A bare string would otherwise be iterated character by character.
        include = (include,)
    suffixes = None if include is None else tuple(include)

    file_list = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if suffixes is None or file.endswith(suffixes):
                file_list.append(os.path.join(root, file))
    file_list.sort()
    print("Files found:", len(file_list))
    return file_list
|
|
# Index every image under IMAGES_PATH; `df` holds the same list as one column.
files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])
df = pd.DataFrame(data=files, columns=['image_path'])
|
|
class VaaniDataset(torch.utils.data.Dataset):
    """Image dataset over a list of file paths.

    Each item is read as 3-channel RGB, resized to ``im_size x im_size`` and
    converted to ``float32`` with ``ToDtype(..., scale=True)`` (uint8 pixel
    values rescaled to [0, 1]).
    """

    def __init__(self, files_paths, im_size):
        self.files_paths = files_paths
        self.im_size = im_size
        # Build the transform pipeline once instead of on every __getitem__.
        self.transform = v2.Compose([
            v2.Resize((im_size, im_size)),
            v2.ToDtype(torch.float32, scale=True),
        ])

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, idx):
        image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)
        return self.transform(image)
|
|
dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size)
if len(dataset) == 0:
    raise RuntimeError(f"No images found under {IMAGES_PATH!r}")
# Sanity-check one sample (index 2 as before, when the dataset is that large).
image = dataset[min(2, len(dataset) - 1)]
print('IMAGE SHAPE:', image.shape)

# DotDict returns a (truthy) empty placeholder for a missing key, so read the
# flag from the underlying attributes to treat an absent `debug` as False.
if vars(train_config).get("debug", False):
    s = 0.001
    # A dedicated generator seeded with 42 should give the same permutation as
    # `torch.manual_seed(42)` did, without reseeding the global RNG.
    dataset, _ = torch.utils.data.random_split(
        dataset, [s, 1 - s], generator=torch.Generator().manual_seed(42)
    )
print("Length of Train dataset:", len(dataset))

mode = sys.argv[1] if len(sys.argv) > 1 else None
if mode == "train_vae":
    BATCH_SIZE = train_config.autoencoder_batch_size
elif mode == "train_ldm":
    BATCH_SIZE = train_config.ldm_batch_size
else:
    sys.exit(f"usage: {sys.argv[0]} {{train_vae|train_ldm}} (got {mode!r})")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Never request more worker processes than the machine has CPUs.
    num_workers=min(48, os.cpu_count() or 1),
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)
if len(dataloader) == 0:
    raise RuntimeError(
        f"drop_last=True with batch size {BATCH_SIZE} leaves no batches for "
        f"{len(dataset)} samples"
    )

images = next(iter(dataloader))
print('BATCH SHAPE:', images.shape)

dataset_config = Config.dataset_params
autoencoder_config = Config.autoencoder_params
train_config = Config.train_params
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| import time |
|
|
def format_time(t1, t2):
    """Format the elapsed time ``t2 - t1`` (seconds) as a readable string.

    Examples: ``"5.00 seconds"``, ``"2 minutes 3.50 seconds"``,
    ``"1 hours 1 minutes 1.00 seconds"``,
    ``"1 days 0 hours 0 minutes 0.00 seconds"``.
    """
    # Round to the displayed precision *before* splitting into units, so e.g.
    # 59.999 s prints as "1 minutes 0.00 seconds" rather than "60.00 seconds",
    # and 119.999 s never becomes "1 minutes 60.00 seconds".
    elapsed_time = round(t2 - t1, 2)
    if elapsed_time < 60:
        return f"{elapsed_time:.2f} seconds"
    if elapsed_time < 3600:
        minutes, seconds = divmod(elapsed_time, 60)
        return f"{minutes:.0f} minutes {seconds:.2f} seconds"
    if elapsed_time < 86400:
        hours, remainder = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
    days, remainder = divmod(elapsed_time, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds"
|
|
|
|
def find_checkpoints(checkpoint_path):
    """List the ``f"{checkpoint_path}_epoch<N>.pt"`` files, ordered by epoch ``N``.

    Works for bare filenames too (searched in the current directory). Returned
    paths keep ``checkpoint_path``'s directory part as given. Returns ``[]`` if
    the directory does not exist.
    """
    directory = os.path.dirname(checkpoint_path)
    prefix = os.path.basename(checkpoint_path)
    pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$")

    try:
        # dirname("ckpt") is "", which os.listdir rejects; use the cwd instead.
        names = os.listdir(directory or os.curdir)
    except (FileNotFoundError, NotADirectoryError):
        return []

    matches = []
    for name in names:
        match = pattern.match(name)
        if match:
            matches.append((int(match.group(1)), os.path.join(directory, name)))
    # Numeric epoch order (epoch2 before epoch10), not listdir's arbitrary order.
    return [path for _, path in sorted(matches)]
| |
| |
def save_vae_checkpoint(
    total_steps, epoch, model, discriminator, optimizer_d,
    optimizer_g, metrics, checkpoint_path, logs, total_training_time,
    keep_last=2,
):
    """Save a full VQVAE training checkpoint and prune older epochs.

    Writes ``f"{checkpoint_path}_epoch{epoch}.pt"`` with the model,
    discriminator and both optimizer state dicts plus bookkeeping. It then
    deletes all but the ``keep_last`` (default 2, at least 1) highest-epoch
    ``_epoch<N>.pt`` files for the same ``checkpoint_path``.

    The data is first written to a ``.tmp`` sibling and atomically renamed,
    so an interrupted save never leaves a truncated ``.pt`` file behind.
    """
    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_d_state_dict": optimizer_d.state_dict(),
        "optimizer_g_state_dict": optimizer_g.state_dict(),
        "metrics": metrics,
        "logs": logs,
        "total_training_time": total_training_time
    }
    checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt"
    ckpt_dir = os.path.dirname(checkpoint_file)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    tmp_file = f"{checkpoint_file}.tmp"
    try:
        torch.save(checkpoint, tmp_file)
        os.replace(tmp_file, checkpoint_file)
    except BaseException:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    print(f"VQVAE Checkpoint saved at {checkpoint_file}")

    # glob.escape: the path may contain glob metacharacters such as "[".
    epoch_re = re.compile(r"_epoch(\d+)\.pt$")
    numbered = []
    for path in glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt"):
        match = epoch_re.search(path)
        if match:  # never delete look-alikes such as "..._epoch3_backup.pt"
            numbered.append((int(match.group(1)), path))
    numbered.sort()
    for _, old_ckpt in numbered[:-max(1, keep_last)]:
        os.remove(old_ckpt)
        print(f"Removed old VQVAE checkpoint: {old_ckpt}")
|
|
|
|
|
|
def load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g, device=None):
    """Restore the newest readable VQVAE checkpoint saved by ``save_vae_checkpoint``.

    Candidates are the ``f"{checkpoint_path}_epoch<N>.pt"`` files, tried from
    the highest epoch down. A truncated or corrupt newest file (e.g. from an
    interrupted save) therefore falls back to the previous one.

    Args:
        checkpoint_path: path prefix used when saving.
        model, discriminator, optimizer_d, optimizer_g: objects whose state
            is restored in place.
        device: ``map_location`` for ``torch.load``. Defaults to CUDA when
            available, else CPU (the same choice as the script's ``device``).

    Returns:
        ``(total_steps, start_epoch, metrics, logs, total_training_time)``,
        where ``start_epoch`` is the saved epoch + 1, or ``(0, 0, None, [], 0)``
        when no checkpoint file exists.

    Raises:
        RuntimeError: if checkpoint files exist but none of them can be read.
    """
    import pickle

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_epoch(filename):
        match = re.search(r"_epoch(\d+)\.pt$", filename)
        return int(match.group(1)) if match else -1

    # glob.escape: the path may contain glob metacharacters such as "[".
    candidates = glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt")
    all_ckpts = sorted((p for p in candidates if extract_epoch(p) >= 0), key=extract_epoch)
    if not all_ckpts:
        print("No VQVAE checkpoint found. Starting from scratch.")
        return 0, 0, None, [], 0

    last_error = None
    for ckpt_path in reversed(all_ckpts):
        try:
            checkpoint = torch.load(ckpt_path, map_location=device)
        except (RuntimeError, EOFError, FileNotFoundError, pickle.UnpicklingError) as err:
            print(f"Could not read VQVAE checkpoint {ckpt_path}: {err}")
            last_error = err
            continue
        model.load_state_dict(checkpoint["model_state_dict"])
        discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
        optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
        optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
        total_steps = checkpoint["total_steps"]
        epoch = checkpoint["epoch"]
        metrics = checkpoint["metrics"]
        logs = checkpoint.get("logs", [])
        total_training_time = checkpoint.get("total_training_time", 0)
        print(f"VQVAE Checkpoint loaded from {ckpt_path}. Resuming from epoch {epoch + 1}, step {total_steps}")
        return total_steps, epoch + 1, metrics, logs, total_training_time

    # Do not silently restart from scratch: save_vae_checkpoint would then
    # prune the (unreadable but possibly recoverable) existing files.
    raise RuntimeError(
        f"Found {len(all_ckpts)} VQVAE checkpoint(s) for {checkpoint_path!r} but none could be read"
    ) from last_error
|
|
| from PIL import Image |
def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8):
    """Reconstruct the first ``sample_size`` images of ``dataset`` and save a comparison grid.

    The grid shows the inputs in the first row and the reconstructions in the
    second. It is written to
    ``save_path/reconstructed_images_EP-{epoch}_{sample_size}.png`` and, as
    before, to ``output_image.png`` in the current working directory
    (overwritten on every call).

    Args:
        model: module whose call returns ``(reconstruction, _, _)``.
        dataset: indexable dataset of image tensors, presumably float in
            [0, 1] (values are clamped to that range for saving) — confirm.
        save_path: output directory (created if missing).
        epoch: used only in the output filename.
        device: device the batch is moved to.
        sample_size: number of images; capped at ``len(dataset)``.
    """
    os.makedirs(save_path, exist_ok=True)

    # Never index past the end of a small (e.g. debug-subset) dataset.
    sample_size = min(sample_size, len(dataset))
    if sample_size == 0:
        print("inference: dataset is empty, nothing to reconstruct.")
        return

    image_tensors = torch.stack([dataset[i] for i in range(sample_size)]).to(device)

    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs, _, _ = model(image_tensors)
    finally:
        model.train(was_training)

    grid = make_grid(torch.cat([image_tensors.cpu(), outputs.cpu()], dim=0), nrow=sample_size)

    # Clamp before the uint8 cast: reconstructions can leave [0, 1], and a
    # plain .byte() on out-of-range floats does not saturate.
    np_img = grid.mul(255).add(0.5).clamp(0, 255).to(torch.uint8).numpy().transpose(1, 2, 0)
    combined_image = Image.fromarray(np_img)
    combined_image.save("output_image.png")
    combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png"))

    print(f"Reconstructed images saved at: {save_path}")
|
|
|
|
def trainVAE(Config, dataloader):
    """Train the VQVAE autoencoder, resuming from its latest checkpoint.

    Generator objective per batch (every term divided by
    ``train_params.autoencoder_acc_steps`` for gradient accumulation):
    MSE reconstruction + ``codebook_weight`` * codebook loss
    + ``commitment_beta`` * commitment loss + ``perceptual_weight`` * LPIPS,
    plus ``disc_weight`` * least-squares (MSE-to-1) adversarial loss once
    ``total_steps > train_params.disc_start``. From that point the
    discriminator is trained with MSE against 1 (real) / 0 (fake) targets.
    Both optimizers step every ``acc_steps`` batches and once more at the end
    of each epoch to flush leftover accumulated gradients.

    After every epoch a checkpoint is written with ``save_vae_checkpoint`` and
    a reconstruction grid of 16 images is saved under ``<task_name>/vqvae_recon``.

    Args:
        Config: Configuration with ``dataset_params``, ``autoencoder_params``
            and ``train_params`` sections.
        dataloader: Iterable of image batches (tensors). Its ``.dataset``
            (when present) supplies the reconstruction previews.
    """
    import time  # local import: `time` is not among the imports visible at file top

    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    # Reproducibility. `device` is a torch.device, so comparing it to the
    # string "cuda" is False; manual_seed_all is silently ignored without
    # CUDA, so call it unconditionally.
    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)

    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    os.makedirs(train_config.task_name, exist_ok=True)
    checkpoint_path = os.path.join(train_config.task_name, "vqvae_ckpt.pth")
    (total_steps, start_epoch,
     metrics, logs, total_training_time) = load_vae_checkpoint(checkpoint_path,
                                                               model, discriminator,
                                                               optimizer_d, optimizer_g)

    num_epochs = train_config.autoencoder_epochs
    recon_criterion = torch.nn.MSELoss()
    disc_criterion = torch.nn.MSELoss()
    lpips_model = LPIPS().eval().to(device)

    acc_steps = train_config.autoencoder_acc_steps
    disc_step_start = train_config.disc_start

    # Reconstruction previews come from the dataloader's own dataset; fall
    # back to the module-level `dataset` when the dataloader has none.
    recon_dataset = getattr(dataloader, "dataset", None)
    if recon_dataset is None:
        recon_dataset = dataset
    recon_save_path = os.path.join(train_config.task_name, 'vqvae_recon')

    # Offset the clock so total_training_time keeps accumulating across resumes.
    start_time_total = time.time() - total_training_time

    for epoch_idx in trange(start_epoch, num_epochs, colour='red', dynamic_ncols=True):
        start_time_epoch = time.time()
        epoch_log = []

        for images in tqdm(dataloader, colour='green', dynamic_ncols=True):
            batch_start_time = time.time()
            total_steps += 1
            use_disc = total_steps > disc_step_start
            step_now = total_steps % acc_steps == 0

            images = images.to(device)
            output, _, quantize_losses = model(images)

            # ---- Generator (autoencoder) loss ----
            recon_loss = recon_criterion(output, images) / acc_steps
            g_loss = (
                recon_loss
                + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps)
                + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps)
            )

            # Freeze discriminator weights while building the adversarial term
            # so g_loss.backward() cannot push gradients into the discriminator;
            # gradients still flow through it back to `output`.
            discriminator.requires_grad_(False)
            if use_disc:
                disc_fake_pred = discriminator(output)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
                g_loss = g_loss + train_config.disc_weight * disc_fake_loss / acc_steps

            lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps
            g_loss = g_loss + train_config.perceptual_weight * lpips_loss
            g_loss.backward()
            discriminator.requires_grad_(True)

            # ---- Discriminator loss (fakes detached from the generator graph) ----
            if use_disc:
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(images)
                disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
                disc_real_loss = disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps
                disc_loss.backward()
                if step_now:
                    optimizer_d.step()
                    optimizer_d.zero_grad()

            # Single generator update per accumulation window.
            if step_now:
                optimizer_g.step()
                optimizer_g.zero_grad()

            epoch_log.append(format_time(0, time.time() - batch_start_time))

        # Flush gradients accumulated since the last step of this epoch.
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()

        epoch_time = time.time() - start_time_epoch
        logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})

        total_training_time = time.time() - start_time_total

        save_vae_checkpoint(total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time)
        inference(model, recon_dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16)

    print("Training completed.")
|
|
|
|
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
|
|
|
|
|
|
| |
| |
| |
class LinearNoiseScheduler:
    r"""
    DDPM noise schedule with betas spaced linearly in square-root space.

    ``betas = linspace(sqrt(beta_start), sqrt(beta_end), num_timesteps) ** 2``.
    All derived tables (alphas, cumulative products and their square roots)
    are computed once on the CPU and moved to the caller's device on use.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        root_betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps)
        self.betas = root_betas ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Sample x_t from q(x_t | x_0).

        Returns ``sqrt(alpha_bar[t]) * original + sqrt(1 - alpha_bar[t]) * noise``
        with one coefficient per batch element.

        :param original: clean batch; dim 0 is the batch dimension
        :param noise: noise with the same shape as ``original``
        :param t: integer timesteps, shape (B,)
        :return: noised batch, same shape as ``original``
        """
        target = original.device
        # (B,) coefficients reshaped to (B, 1, 1, ...) so they broadcast.
        coeff_shape = (original.shape[0],) + (1,) * (len(original.shape) - 1)
        signal_scale = self.sqrt_alpha_cum_prod.to(target)[t].reshape(coeff_shape)
        noise_scale = self.sqrt_one_minus_alpha_cum_prod.to(target)[t].reshape(coeff_shape)
        return signal_scale * original + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        One reverse step: compute x_{t-1} from x_t and the predicted noise.

        :param xt: sample at timestep ``t``
        :param noise_pred: model's noise prediction for ``xt``
        :param t: scalar (0-d) timestep
        :return: ``(x_{t-1}, x0)`` where ``x0`` is the clamped [-1, 1] estimate
                 of the clean sample; at ``t == 0`` the posterior mean is
                 returned without added noise
        """
        target = xt.device
        betas = self.betas.to(target)
        alphas = self.alphas.to(target)
        alpha_bar = self.alpha_cum_prod.to(target)
        root_one_minus_alpha_bar = self.sqrt_one_minus_alpha_cum_prod.to(target)

        x0 = (xt - (root_one_minus_alpha_bar[t] * noise_pred)) / torch.sqrt(alpha_bar[t])
        x0 = torch.clamp(x0, -1., 1.)

        mean = xt - (betas[t] * noise_pred) / root_one_minus_alpha_bar[t]
        mean = mean / torch.sqrt(alphas[t])

        if t == 0:
            return mean, x0

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        variance = (1 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t]) * betas[t]
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(target)
        return mean + sigma * z, x0
|
|
|
|
|
|
| |
| |
| |
def get_time_embedding(time_steps, temb_dim):
    r"""
    Sinusoidal timestep embedding.

    With ``half = temb_dim // 2`` the frequencies are
    ``1 / 10000 ** (k / half)`` for ``k`` in ``[0, half)``; the result is
    ``[sin(t * f), cos(t * f)]`` concatenated on the last dimension.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: embedding dimension (must be even)
    :return: (B, temb_dim) embedding of the B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half = temb_dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=time_steps.device) / half
    factor = 10000 ** exponents

    # Broadcasting (B, 1) against (half,) gives the (B, half) angle table.
    angles = time_steps[:, None] / factor
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
|
|
|
|
|
|
|
|
|
|
| |
| |
| |
class UpBlockUnet(nn.Module):
    r"""
    U-Net decoder stage with self-attention and optional cross-attention.

    forward() performs, in order:
      1. upsample the input with a stride-2 ConvTranspose2d (identity when
         ``up_sample`` is False);
      2. concatenate the skip tensor ``out_down`` along the channel axis;
      3. ``num_layers`` times: a ResNet block (the projected timestep
         embedding is added between its two convs when ``t_emb_dim`` is set),
         then self-attention over spatial positions, then cross-attention to
         ``context`` when ``cross_attn`` is True. Each sub-block is residual.

    ``in_channels`` is the width after concatenation; the upsampling conv works
    on ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        # Input width per layer: the first layer sees the concatenated tensor.
        widths_in = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        self.resnet_conv_first = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for width in widths_in
        )

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList(
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            )

        self.resnet_conv_second = nn.ModuleList(
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        )

        self.attention_norms = nn.ModuleList(
            nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
        )
        self.attentions = nn.ModuleList(
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)
        )

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)
            )
            self.cross_attentions = nn.ModuleList(
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)
            )
            self.context_proj = nn.ModuleList(
                nn.Linear(context_dim, out_channels) for _ in range(num_layers)
            )

        # 1x1 convs matching the residual branch width to out_channels.
        self.residual_input_conv = nn.ModuleList(
            nn.Conv2d(width, out_channels, kernel_size=1) for width in widths_in
        )

        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        hidden = x
        for layer in range(self.num_layers):
            # ResNet block with a 1x1 projected skip connection.
            shortcut = hidden
            hidden = self.resnet_conv_first[layer](hidden)
            if self.t_emb_dim is not None:
                hidden = hidden + self.t_emb_layers[layer](t_emb)[:, :, None, None]
            hidden = self.resnet_conv_second[layer](hidden)
            hidden = hidden + self.residual_input_conv[layer](shortcut)

            # Self-attention over the H*W spatial tokens.
            bsz, chans, height, width = hidden.shape
            tokens = self.attention_norms[layer](hidden.reshape(bsz, chans, height * width)).transpose(1, 2)
            attended, _ = self.attentions[layer](tokens, tokens, tokens)
            hidden = hidden + attended.transpose(1, 2).reshape(bsz, chans, height, width)

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                bsz, chans, height, width = hidden.shape
                tokens = self.cross_attention_norms[layer](hidden.reshape(bsz, chans, height * width)).transpose(1, 2)
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                projected = self.context_proj[layer](context)
                attended, _ = self.cross_attentions[layer](tokens, projected, projected)
                hidden = hidden + attended.transpose(1, 2).reshape(bsz, chans, height, width)

        return hidden
|
|
|
|
|
|
| |
| |
| |
class Unet(nn.Module):
    r"""
    Latent-diffusion U-Net: down blocks, mid blocks and up blocks, with
    optional audio cross-attention conditioning.

    When ``'audio'`` is in ``condition_config.condition_types``, the
    conditioning sequence passed to :meth:`forward` is attention-pooled to a
    single token: ``context_projector`` scores every position, the scores are
    softmaxed over the sequence, and the weighted sum is given to every block's
    cross-attention. Without audio conditioning (including when
    ``condition_config`` is None) the blocks are built without cross-attention
    and ``cond_input`` is ignored.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.t_emb_dim = model_config.time_emb_dim
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers
        self.attns = model_config.attn_down
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads
        self.conv_out_channels = model_config.conv_out_channels

        # Channel bookkeeping that the down/mid/up wiring below relies on.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Conditioning. Default to an unconditional model so that configs
        # without audio (or without any condition_config) still build.
        self.condition_config = model_config.condition_config
        condition_types = None
        if self.condition_config is not None:
            condition_types = self.condition_config.condition_types
        self.cond = condition_types
        self.audio_cond = condition_types is not None and 'audio' in condition_types
        self.audio_embed_dim = None
        if self.audio_cond:
            self.audio_embed_dim = self.condition_config.audio_condition_config.audio_embed_dim

        # MLP applied on top of the sinusoidal timestep embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        if self.audio_cond:
            # Scores each conditioning position (audio_embed_dim -> 1) for
            # softmax attention pooling in forward().
            self.context_projector = nn.Sequential(
                nn.Linear(self.audio_embed_dim, 320),
                nn.SiLU(),
                nn.Linear(320, 1)
            )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(
                DownBlock(
                    self.down_channels[i],
                    self.down_channels[i + 1],
                    self.t_emb_dim,
                    down_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_down_layers,
                    attn=self.attns[i],
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(
                MidBlock(
                    self.mid_channels[i],
                    self.mid_channels[i + 1],
                    self.t_emb_dim,
                    num_heads=self.num_heads,
                    num_layers=self.num_mid_layers,
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        # Up blocks mirror the down path; in_channels doubles for the skip concat.
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(
                UpBlockUnet(
                    self.down_channels[i] * 2,
                    self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                    self.t_emb_dim,
                    up_sample=self.down_sample[i],
                    num_heads=self.num_heads,
                    num_layers=self.num_up_layers,
                    norm_channels=self.norm_channels,
                    cross_attn=self.audio_cond,
                    context_dim=self.audio_embed_dim
                )
            )

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_input=None):
        """Predict noise for latent batch ``x`` at timesteps ``t``.

        :param x: latent batch (B, C, H, W)
        :param t: timesteps — a (B,) tensor, or a scalar/0-d value shared by
                  the whole batch; moved to ``x``'s device
        :param cond_input: conditioning sequence, required when audio
                  conditioning is enabled; assumed (B, S, audio_embed_dim) —
                  TODO confirm
        :raises ValueError: if audio conditioning is enabled and
                  ``cond_input`` is None
        """
        out = self.conv_in(x)

        timesteps = torch.as_tensor(t, device=x.device).long().reshape(-1)
        t_emb = self.t_proj(get_time_embedding(timesteps, self.t_emb_dim))

        context_hidden_states = None
        if self.audio_cond:
            if cond_input is None:
                raise ValueError("cond_input is required when audio conditioning is enabled")
            # Softmax attention pooling over the sequence axis -> one token.
            weights = torch.softmax(self.context_projector(cond_input), dim=1)
            context_hidden_states = (cond_input * weights).sum(dim=1).unsqueeze(1)

        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb, context_hidden_states)

        for mid in self.mids:
            out = mid(out, t_emb, context_hidden_states)

        for up in self.ups:
            out = up(out, down_outs.pop(), t_emb, context_hidden_states)

        out = nn.functional.silu(self.norm_out(out))
        return self.conv_out(out)
| |
| |
|
|
| |
| |
| |
def find_checkpoints(checkpoint_path):
    """Return paths of files named ``<basename>_epoch<N>.pt`` beside ``checkpoint_path``.

    ``checkpoint_path`` is a path prefix such as ``run/vqvae_ckpt``, not a file.
    Matches are returned in directory-listing order (unsorted), each joined
    onto the prefix's directory component. A bare prefix (no directory) searches
    the current working directory. A missing directory yields ``[]``.
    """
    directory = os.path.dirname(checkpoint_path)
    prefix = os.path.basename(checkpoint_path)
    pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt")

    # os.listdir("") raises, so an empty dirname means "current directory".
    try:
        files = os.listdir(directory or os.curdir)
    except (FileNotFoundError, NotADirectoryError):
        return []

    return [os.path.join(directory, f) for f in files if pattern.fullmatch(f)]
|
|
def save_ldm_checkpoint(checkpoint_path,
                        total_steps, epoch, model, optimizer,
                        metrics, logs, total_training_time,
                        keep_last=2,
                        ):
    """Save an LDM training checkpoint and prune older ones.

    Writes ``<checkpoint_path>_epoch<epoch>.pt``. It first writes to a
    ``.tmp`` file and then moves it into place with ``os.replace``, so an
    interrupted save does not leave a truncated file under the final name.

    It then deletes all but the ``keep_last`` highest-epoch checkpoints with
    the same prefix. Only files named exactly ``<basename>_epoch<N>.pt`` are
    candidates for deletion, and the file just written is never deleted.

    Args:
        checkpoint_path: Path prefix, e.g. ``<task>/ldmH_ckpt``.
        total_steps: Global batch counter to store.
        epoch: Epoch number stored and used in the file name.
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        metrics: Metrics object stored as-is.
        logs: Training log list stored as-is.
        total_training_time: Accumulated wall-clock seconds.
        keep_last: Number of newest checkpoints to retain (>= 1).

    Raises:
        ValueError: If ``keep_last`` < 1.
    """
    if keep_last < 1:
        raise ValueError(f"keep_last must be >= 1, got {keep_last}")

    checkpoint = {
        "total_steps": total_steps,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "metrics": metrics,
        "logs": logs,
        "total_training_time": total_training_time
    }

    directory = os.path.dirname(checkpoint_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt"
    tmp_file = f"{checkpoint_file}.tmp"
    torch.save(checkpoint, tmp_file)
    os.replace(tmp_file, checkpoint_file)
    print(f"LDM Checkpoint saved at {checkpoint_file}")

    # glob.escape protects prefixes containing [, ], * or ?. The regex then
    # rejects look-alikes such as "<prefix>_epoch_final.pt" that the glob
    # would also match.
    name_re = re.compile(rf"{re.escape(os.path.basename(checkpoint_path))}_epoch(\d+)\.pt")
    numbered = []
    for path in glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt"):
        match = name_re.fullmatch(os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))
    numbered.sort()

    current = os.path.abspath(checkpoint_file)
    for _, old_ckpt in numbered[:-keep_last]:
        if os.path.abspath(old_ckpt) == current:
            continue
        os.remove(old_ckpt)
        print(f"Removed old LDM checkpoint: {old_ckpt}")
|
|
def load_ldm_checkpoint(checkpoint_path, model, optimizer, map_location=None):
    """Restore LDM training state from the newest ``<checkpoint_path>_epoch<N>.pt``.

    Only files named exactly ``<basename>_epoch<N>.pt`` are considered. The one
    with the largest ``N`` is loaded.

    Args:
        checkpoint_path: Path prefix used by ``save_ldm_checkpoint``.
        model: Module that receives ``model_state_dict`` (in place).
        optimizer: Optimizer that receives ``optimizer_state_dict`` (in place).
        map_location: Passed to ``torch.load``; defaults to the module-level
            ``device`` when omitted.

    Returns:
        ``(total_steps, start_epoch, metrics, logs, total_training_time)``,
        where ``start_epoch`` is the stored epoch + 1. Returns
        ``(0, 0, None, [], 0)`` when no checkpoint exists.
    """
    # Escape glob metacharacters in the prefix. The regex also rejects
    # look-alikes such as "<prefix>_epoch_final.pt".
    name_re = re.compile(rf"{re.escape(os.path.basename(checkpoint_path))}_epoch(\d+)\.pt")
    candidates = []
    for path in glob.glob(f"{glob.escape(checkpoint_path)}_epoch*.pt"):
        match = name_re.fullmatch(os.path.basename(path))
        if match:
            candidates.append((int(match.group(1)), path))

    if not candidates:
        print("No LDM checkpoint found. Starting from scratch.")
        return 0, 0, None, [], 0

    _, latest_ckpt = max(candidates)
    if map_location is None:
        map_location = device
    checkpoint = torch.load(latest_ckpt, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    metrics = checkpoint["metrics"]
    logs = checkpoint.get("logs", [])
    total_training_time = checkpoint.get("total_training_time", 0)
    print(f"LDM Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}")
    return total_steps, epoch + 1, metrics, logs, total_training_time
| |
def load_ldm_vae_checkpoint(checkpoint_path, vae, device=device):
    """Load the newest ``<checkpoint_path>_epoch<N>.pt`` VQVAE weights into ``vae``.

    Only ``model_state_dict`` is restored; no optimizer or discriminator state.
    Candidates come from ``find_checkpoints``.

    NOTE(review): ``trainVAE`` passes ``vqvae_ckpt.pth`` to
    ``save_vae_checkpoint`` while callers here pass the prefix ``vqvae_ckpt``.
    Confirm the two naming schemes match.

    Args:
        checkpoint_path: Path prefix of the VQVAE checkpoints.
        vae: Module that receives the weights (in place).
        device: ``map_location`` for ``torch.load``.

    Returns:
        ``(total_steps, next_epoch, metrics, logs, total_training_time)`` in the
        same layout as ``load_ldm_checkpoint``. Keys other than
        ``total_steps``/``epoch`` are optional in the file and default to
        ``None``/``[]``/``0``. Returns ``(0, 0, None, [], 0)`` when no
        checkpoint exists.
    """
    all_ckpts = find_checkpoints(checkpoint_path)
    if not all_ckpts:
        print("No VQVAE checkpoint found.")
        return 0, 0, None, [], 0

    def extract_epoch(filename):
        match = re.search(r"_epoch(\d+)\.pt$", filename)
        return int(match.group(1)) if match else -1

    latest_ckpt = sorted(all_ckpts, key=extract_epoch)[-1]
    checkpoint = torch.load(latest_ckpt, map_location=device)
    vae.load_state_dict(checkpoint["model_state_dict"])
    total_steps = checkpoint["total_steps"]
    epoch = checkpoint["epoch"]
    print(f"VQVAE Checkpoint loaded from {latest_ckpt} at epoch {epoch + 1} & step {total_steps}")
    return (total_steps, epoch + 1, checkpoint.get("metrics"),
            checkpoint.get("logs", []), checkpoint.get("total_training_time", 0))
| |
def trainLDM(Config, dataloader):
    """Train the latent-diffusion U-Net on frozen-VQVAE latents, resuming from checkpoints.

    Per batch: encode images to latents with the frozen VQVAE, add noise at a
    uniformly random timestep, and regress that noise with MSE. Gradients are
    accumulated over ``train_params.ldm_acc_steps`` batches.

    At the start of every epoch the newest VQVAE weights are re-read
    (presumably so weights from a concurrently running VQVAE job are picked
    up — TODO confirm).

    After every epoch the following happen in order:
    - a checkpoint is saved;
    - ``infer(Config)`` renders samples;
    - the YAML at the module-level ``config_path`` is re-read, and training
      stops if ``training.continue_ldm`` equals False.

    Batches may be bare image tensors or ``(images, cond_input, ...)``
    sequences. ``cond_input`` is required only for audio conditioning, where
    it is fed to ``CLAP.get_audio_embeddings``.
    """
    import time  # local import: `time` is not among the imports visible at file top

    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params
    condition_config = diffusion_model_config.condition_config

    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

    # The U-Net trains on the module-level `device`. The frozen VQVAE gets a
    # second GPU when one exists and otherwise shares that device.
    ldm_device = device
    vqvae_device = torch.device("cuda:1") if torch.cuda.device_count() > 1 else device

    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)

    condition_types = []
    if condition_config is not None:
        condition_types = condition_config.condition_types or []
    use_audio = (not train_config.ldm_pretraining) and 'audio' in condition_types

    # The all-zeros "no audio" sequence is 1500 x width. The width is the
    # configured audio_embed_dim when audio is configured, else 1280.
    # 1500 positions is a fixed assumption (presumably Whisper encoder
    # frames — TODO confirm).
    audio_embed_dim = 1280
    if 'audio' in condition_types:
        audio_embed_dim = condition_config.audio_condition_config.audio_embed_dim

    audio_model = None
    if use_audio:
        from msclap import CLAP
        audio_model = CLAP(version='2023', use_cuda=str(device).startswith("cuda"))

    model = Unet(im_channels=autoencoder_model_config.z_channels,
                 model_config=diffusion_model_config).to(ldm_device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.ldm_lr)
    criterion = torch.nn.MSELoss()
    num_epochs = train_config.ldm_epochs

    os.makedirs(train_config.task_name, exist_ok=True)
    checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "ldmH_ckpt")
    (total_steps, start_epoch, metrics, logs,
     total_training_time) = load_ldm_checkpoint(checkpoint_path, model, optimizer)

    # Frozen VQVAE encoder. Its weights are (re)loaded at the top of each epoch.
    # requires_grad=False survives load_state_dict, so freezing once is enough.
    vae = VQVAE(im_channels=dataset_config.im_channels,
                model_config=autoencoder_model_config).eval().to(vqvae_device)
    for param in vae.parameters():
        param.requires_grad = False
    vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt")

    acc_steps = train_config.ldm_acc_steps
    # Offset the clock so total_training_time keeps accumulating across resumes.
    start_time_total = time.time() - total_training_time

    model.train()
    optimizer.zero_grad()
    for epoch_idx in trange(start_epoch, num_epochs, desc=f"{device}-LDM Epoch", colour='red', dynamic_ncols=True):
        start_time_epoch = time.time()
        losses = []
        epoch_log = []

        load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device)
        vae.eval()

        for batch in tqdm(dataloader, colour='green', dynamic_ncols=True):
            batch_start_time = time.time()
            total_steps += 1

            if isinstance(batch, (list, tuple)):
                images = batch[0]
                cond_input = batch[1] if len(batch) > 1 else None
            else:
                images, cond_input = batch, None
            batch_size = images.shape[0]

            with torch.no_grad():
                latents, _ = vae.encode(images.to(vqvae_device))
            latents = latents.to(ldm_device)

            empty_audio_embedding = torch.zeros((batch_size, 1500, audio_embed_dim), device=ldm_device).float()
            if use_audio:
                if cond_input is None:
                    raise ValueError(
                        "Audio conditioning is enabled but the dataloader batches "
                        "carry no conditioning input"
                    )
                # Assumes get_audio_embeddings returns (B, S, audio_embed_dim)
                # — TODO confirm against the msclap version in use.
                with torch.no_grad():
                    audio_embeddings = audio_model.get_audio_embeddings(cond_input)
                # Classifier-free guidance training: swap a random subset of
                # conditions for the empty embedding.
                drop_prob = condition_config.audio_condition_config.cond_drop_prob
                drop_mask = torch.zeros(batch_size, device=latents.device).float().uniform_(0, 1) < drop_prob
                audio_embeddings[drop_mask, :, :] = empty_audio_embedding[0]
            else:
                audio_embeddings = empty_audio_embedding

            noise = torch.randn_like(latents)
            t = torch.randint(0, diffusion_config.num_timesteps, (batch_size,)).to(ldm_device)

            noisy_latents = scheduler.add_noise(latents, noise, t)
            noise_pred = model(noisy_latents, t, cond_input=audio_embeddings)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            (loss / acc_steps).backward()

            # Exactly one optimizer update per accumulation window.
            if total_steps % acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            epoch_log.append(format_time(0, time.time() - batch_start_time))

        print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')

        epoch_time = time.time() - start_time_epoch
        logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log})

        total_training_time = time.time() - start_time_total
        save_ldm_checkpoint(checkpoint_path, total_steps, epoch_idx + 1, model, optimizer, metrics, logs, total_training_time)

        # infer() builds its own models from the config and checkpoints.
        infer(Config)

        # Allow stopping a running job by editing the YAML config on disk.
        with open(config_path, 'r') as config_file:
            train_continue = DotDict.from_dict(yaml.safe_load(config_file))
        if train_continue.training.continue_ldm == False:  # noqa: E712 — only an explicit False stops
            print('LDM Training Stoped ...')
            break

    print('Done Training ...')
|
|
|
|
|
|
| |
| |
| |
def sample(model, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config,
           vae, audio_model
           ):
    r"""
    Run the reverse diffusion chain and save one image grid per timestep.

    Starts from Gaussian noise in latent space with spatial size
    ``im_size // 2 ** sum(down_sample)`` and steps ``t = T-1 ... 0``. For
    ``t > 0`` the saved grid is the clamped x0 prediction taken straight from
    latent space (not decoded). At ``t = 0`` the final latent is decoded with
    ``vae``. Grids are written to ``<task_name>/samplesH/x0_<t>.png``.

    Classifier-free guidance is applied when ``train_config.cf_guidance_scale > 1``.
    Audio-conditioned sampling is not implemented here (``audio_model`` is
    unused). The empty conditioning embedding is used for the conditional
    branch as well, so samples are effectively unconditional.
    """
    im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample)
    num_samples = train_config.num_samples
    xt = torch.randn((num_samples,
                      autoencoder_model_config.z_channels,
                      im_size,
                      im_size)).to(device)

    # Empty conditioning: 1500 positions (presumably Whisper frames — TODO
    # confirm) x the configured audio width, or 1280 when audio is not configured.
    condition_config = diffusion_model_config.condition_config
    audio_embed_dim = 1280
    if condition_config is not None and 'audio' in (condition_config.condition_types or []):
        audio_embed_dim = condition_config.audio_condition_config.audio_embed_dim
    uncond_input = torch.zeros((num_samples, 1500, audio_embed_dim), device=device).float()

    if not train_config.ldm_pretraining:
        print("Audio-conditioned sampling is not implemented; "
              "sampling with the empty conditioning embedding.")
    cond_input = uncond_input

    out_dir = os.path.join(train_config.task_name, 'samplesH')
    os.makedirs(out_dir, exist_ok=True)
    cf_guidance_scale = train_config.cf_guidance_scale

    for i in tqdm(reversed(range(diffusion_config.num_timesteps)),
                  total=diffusion_config.num_timesteps,
                  colour='blue', desc="Sampling", dynamic_ncols=True):
        t = (torch.ones((xt.shape[0],)) * i).long().to(device)

        noise_pred = model(xt, t, cond_input)
        if cf_guidance_scale > 1:
            noise_pred_uncond = model(xt, t, uncond_input)
            noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred - noise_pred_uncond)

        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        # Decode only the final step; earlier grids show the latent-space x0
        # prediction directly.
        ims = vae.decode(xt) if i == 0 else x0_pred

        # Map [-1, 1] to [0, 1], then to uint8 pixels.
        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config.num_grid_rows)
        np_img = (grid * 255).byte().numpy().transpose(1, 2, 0)
        img = Image.fromarray(np_img)
        img.save(os.path.join(out_dir, 'x0_{}.png'.format(i)))
        img.close()
|
|
|
|
def infer(Config):
    """Build the U-Net and VQVAE, load their weights, and run :func:`sample`.

    U-Net weights come from ``<task_name>/<ldm_ckpt_name>``. VQVAE weights are
    taken from the same file when it holds ``vae_state_dict``. Checkpoints
    written by ``save_ldm_checkpoint`` contain only the U-Net; in that case the
    VQVAE is loaded from the newest ``<task_name>/vqvae_ckpt_epoch<N>.pt`` via
    ``load_ldm_vae_checkpoint``, as ``trainLDM`` does.

    A missing LDM checkpoint is reported, and sampling proceeds with the
    freshly initialised U-Net.
    """
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params

    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)

    model = Unet(im_channels=autoencoder_model_config.z_channels,
                 model_config=diffusion_model_config).eval().to(device)
    vae = VQVAE(im_channels=dataset_config.im_channels,
                model_config=autoencoder_model_config).eval().to(device)

    os.makedirs(train_config.task_name, exist_ok=True)

    vae_loaded = False
    checkpoint_path = os.path.join(train_config.task_name, train_config.ldm_ckpt_name)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        if "vae_state_dict" in checkpoint:
            vae.load_state_dict(checkpoint["vae_state_dict"])
            vae_loaded = True
            print('Loaded unet & vae checkpoint')
        else:
            print('Loaded unet checkpoint')
    else:
        print(f"No LDM checkpoint at {checkpoint_path}; sampling with untrained U-Net weights.")

    if not vae_loaded:
        load_ldm_vae_checkpoint(os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt"), vae, device)
        vae.eval()

    with torch.no_grad():
        sample(model, scheduler, train_config, diffusion_model_config,
               autoencoder_model_config, diffusion_config, dataset_config, vae, None)
|
|
|
|
|
|
|
|
| |
| |
| |
| |
|
|
| |
|
|
# Entry point: first CLI argument selects the mode ("train_vae", "train_ldm");
# anything else, or no argument at all, runs inference.
_run_mode = sys.argv[1] if len(sys.argv) > 1 else None
if _run_mode == 'train_vae':
    trainVAE(Config, dataloader)
elif _run_mode == 'train_ldm':
    trainLDM(Config, dataloader)
else:
    infer(Config)
| |
| |
|
|
| |
| |