|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

# Pin this process to physical GPU index 1. Must be set before torch/CUDA
# is initialized, which is why it sits above the torch imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
|
|
|
|
"""Lpips""" |
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.init as init |
|
|
from torch.autograd import Variable |
|
|
import numpy as np |
|
|
import torch.nn |
|
|
import torchvision |
|
|
|
|
|
|
|
|
|
|
|
# Module-wide default device: CUDA if available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
def spatial_average(in_tens, keepdim=True):
    """Average a (B, C, H, W) tensor over its two spatial dimensions.

    :param in_tens: 4D tensor whose dims 2 and 3 are spatial.
    :param keepdim: keep the reduced dims as size-1 axes (default True).
    :return: tensor of shape (B, C, 1, 1) if keepdim else (B, C).
    """
    return torch.mean(in_tens, dim=(2, 3), keepdim=keepdim)
|
|
|
|
|
|
|
|
class vgg16(torch.nn.Module):
    """Frozen VGG16 feature extractor used as the LPIPS backbone.

    Exposes the five intermediate activations (relu1_2, relu2_2, relu3_3,
    relu4_3, relu5_3) as a named tuple.

    :param requires_grad: when False (default), all parameters are frozen.
    :param pretrained: when True (default), load the ImageNet checkpoint.
        Bug fix: this flag was previously ignored and the ImageNet weights
        were always downloaded regardless of its value.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        # Honor the `pretrained` flag (previously always IMAGENET1K_V1).
        weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        vgg_pretrained_features = torchvision.models.vgg16(weights=weights).features

        # Split the VGG feature stack into five slices, one per ReLU tap.
        # Original torchvision layer indices are kept as module names so the
        # state_dict layout matches the hand-rolled version.
        slice_bounds = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        self.N_slices = len(slice_bounds)
        for slice_num, (start, end) in enumerate(slice_bounds, start=1):
            seq = torch.nn.Sequential()
            for layer_idx in range(start, end):
                seq.add_module(str(layer_idx), vgg_pretrained_features[layer_idx])
            setattr(self, 'slice%d' % slice_num, seq)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five tap activations for an image batch X.

        :param X: input batch — presumably (B, 3, H, W); VGG16 requires
            3 input channels.
        :return: named tuple with fields relu1_2 ... relu5_3.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
|
|
|
|
|
|
|
|
|
|
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity (LPIPS) metric.

    Distances are computed between two images as the weighted squared
    difference of channel-normalized VGG16 features, with a learned 1x1
    conv per feature level.

    NOTE(review): the learned linear weights are NOT loaded from the
    official checkpoint here — the original code only printed the expected
    `weights/v{version}/{net}.pth` path (leftover debug, removed) and the
    `lins` layers keep their random initialization. Add an actual
    `load_state_dict` call if faithful LPIPS values are required.

    :param net: backbone identifier (only the VGG16 path is implemented).
    :param version: LPIPS weights version string (kept for compatibility).
    :param use_dropout: include dropout in the per-level linear layers.
    """

    def __init__(self, net='vgg', version='0.1', use_dropout=True):
        super(LPIPS, self).__init__()
        self.version = version
        # Shift/scale inputs in [-1, 1] to the statistics the backbone expects.
        self.scaling_layer = ScalingLayer()

        # Channel counts of the five VGG16 tap activations.
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.net = vgg16(pretrained=True, requires_grad=False)

        # One learned 1x1 conv per feature level; lin0..lin4 kept as
        # attributes for backward compatibility with existing checkpoints.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])

        # LPIPS is a frozen metric: eval mode, no gradients anywhere.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, in0, in1, normalize=False):
        """Compute the LPIPS distance between two image batches.

        :param in0: first image batch.
        :param in1: second image batch, same shape as in0.
        :param normalize: set True when inputs are in [0, 1]; they are then
            rescaled to the [-1, 1] range the backbone expects.
        :return: per-image distance tensor of shape (B, 1, 1, 1).
        """
        if normalize:
            # Map [0, 1] -> [-1, 1].
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        diffs = {}
        for kk in range(self.L):
            # Unit-normalize each feature map along the channel dim
            # (dim=1 now explicit on both calls), then take squared diffs.
            feat0 = torch.nn.functional.normalize(outs0[kk], dim=1)
            feat1 = torch.nn.functional.normalize(outs1[kk], dim=1)
            diffs[kk] = (feat0 - feat1) ** 2

        # Weight each level with its 1x1 conv, average spatially, sum levels.
        res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        return sum(res)
|
|
|
|
|
|
|
|
class ScalingLayer(nn.Module):
    """Affine per-channel normalization used in front of the LPIPS backbone."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Registered as buffers so they follow .to(device)/.cuda() but are
        # never updated by the optimizer.
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        """Apply (inp - shift) / scale channel-wise."""
        centered = inp - self.shift
        return centered / self.scale
|
|
|
|
|
|
|
|
class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        # Optional dropout, then a bias-free 1x1 projection conv.
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Project x channel-wise through the 1x1 conv."""
        return self.model(x)
|
|
|
|
|
"""Blocks""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class LinearNoiseScheduler:
    r"""
    DDPM linear noise scheduler.

    Betas are interpolated linearly in sqrt-space between ``beta_start`` and
    ``beta_end``; alphas, their cumulative products and the square roots are
    all precomputed at construction time.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        # Interpolate in sqrt-space, then square back to get the betas.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.

        :param original: Image on which noise is to be applied
        :param noise: Random Noise Tensor (from normal dist)
        :param t: timestep of the forward process of shape -> (B,)
        :return: noised sample, same shape as ``original``.
        """
        batch = original.shape[0]
        trailing = len(original.shape) - 1
        # Gather per-sample scalars, shaped (B, 1, 1, ...) for broadcasting.
        signal_scale = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch, *([1] * trailing))
        noise_scale = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch, *([1] * trailing))
        return signal_scale * original + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        One reverse-diffusion step: x_{t-1} from x_t and the predicted noise.

        :param xt: current timestep sample
        :param noise_pred: model noise prediction
        :param t: current timestep we are at
        :return: (x_{t-1} sample, x_0 prediction clamped to [-1, 1]).
        """
        dev = xt.device
        sqrt_one_minus = self.sqrt_one_minus_alpha_cum_prod.to(dev)[t]
        x0 = (xt - (sqrt_one_minus * noise_pred)) / torch.sqrt(self.alpha_cum_prod.to(dev)[t])
        x0 = torch.clamp(x0, -1., 1.)

        # Posterior mean of q(x_{t-1} | x_t, x_0).
        mean = xt - ((self.betas.to(dev)[t]) * noise_pred) / sqrt_one_minus
        mean = mean / torch.sqrt(self.alphas.to(dev)[t])

        if t == 0:
            # The last step is deterministic — no noise added.
            return mean, x0
        variance = (1 - self.alpha_cum_prod.to(dev)[t - 1]) / (1.0 - self.alpha_cum_prod.to(dev)[t])
        variance = variance * self.betas.to(dev)[t]
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(dev)
        return mean + sigma * z, x0
|
|
|
|
|
|
|
|
import torch |
|
|
import math |
|
|
|
|
|
class CosineNoiseScheduler:
    r"""
    Cosine noise scheduler (Nichol & Dhariwal, "Improved DDPM").

    Bug fix: the previous implementation stored the *cumulative* cosine
    value f(t) directly in ``self.alphas`` and then took a cumulative
    product of it again, so ``alpha_cum_prod`` decayed far too fast and did
    not match the intended alpha-bar curve. This version derives per-step
    betas from the cosine alpha-bar curve (normalized by f(0), betas clipped
    at 0.999), so ``alpha_cum_prod[t] ~= f(t + 1) / f(0)`` as intended.
    The public attributes and method signatures are unchanged.

    :param num_timesteps: number of diffusion steps T.
    :param s: small offset keeping beta away from 0 near t = 0.
    """

    def __init__(self, num_timesteps, s=0.008):
        self.num_timesteps = num_timesteps
        self.s = s

        def alpha_bar(t):
            # f(t) from the paper; decreases from ~1 at t=0 to 0 at t=T.
            return math.cos((t / num_timesteps + s) / (1 + s) * math.pi / 2) ** 2

        f0 = alpha_bar(0)
        alpha_bars = [alpha_bar(t + 1) / f0 for t in range(num_timesteps)]

        # beta_t = 1 - abar(t) / abar(t-1), clipped at 0.999 to avoid the
        # singularity at the end of the schedule where abar -> 0.
        betas = []
        prev = 1.0
        for ab in alpha_bars:
            betas.append(min(1. - ab / prev, 0.999))
            prev = ab
        self.betas = torch.tensor(betas, dtype=torch.float32)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        r"""
        Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.

        :param original: clean sample batch of shape (B, ...).
        :param noise: standard-normal noise, same shape as ``original``.
        :param t: per-sample timestep indices of shape (B,).
        """
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)

        # Broadcast the per-sample scalars across all trailing dims.
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return (sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise)

    def sample_prev_timestep(self, xt, noise_pred, t):
        r"""
        One reverse-diffusion step: x_{t-1} from x_t and the predicted noise.

        :param xt: current timestep sample.
        :param noise_pred: model noise prediction.
        :param t: current (scalar) timestep.
        :return: (x_{t-1} sample, x_0 prediction clamped to [-1, 1]).
        """
        x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
              torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
        x0 = torch.clamp(x0, -1., 1.)

        # Posterior mean, using the per-step beta_t (now stored explicitly).
        mean = xt - (self.betas.to(xt.device)[t] * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
        mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])

        if t == 0:
            # Final step is deterministic.
            return mean, x0
        else:
            variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
            variance = variance * self.betas.to(xt.device)[t]
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)
            return mean + sigma * z, x0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def get_time_embedding(time_steps, temb_dim):
    r"""
    Convert time steps tensor into an embedding using the
    sinusoidal time embedding formula.

    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half = temb_dim // 2
    # Geometric frequency ladder: 10000^(i / half) for i in [0, half).
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=time_steps.device) / half
    freqs = 10000 ** exponents

    # (B, half) matrix of scaled timesteps; sin/cos halves concatenated.
    args = time_steps[:, None].repeat(1, half) / freqs
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
|
|
|
|
|
|
|
class DownBlock(nn.Module):
    r"""
    Down conv block with attention.
    Per layer: a time-conditioned ResNet sub-block, optional self-attention,
    optional cross-attention; one strided-conv downsample at the end.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        def layer_in_ch(idx):
            # Only the first layer sees the incoming channel count.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, layer_in_ch(i)),
                nn.SiLU(),
                nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers)
        ])
        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
                                          4, 2, 1) if self.down_sample else nn.Identity()

    def _resnet(self, idx, feat, t_emb):
        # Conv stack + optional time conditioning + conv stack + 1x1 skip.
        skip = feat
        feat = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            feat = feat + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        feat = self.resnet_conv_second[idx](feat)
        return feat + self.residual_input_conv[idx](skip)

    def _self_attention(self, idx, feat):
        # Self-attention over flattened spatial positions (tokens = H*W).
        b, c, h, w = feat.shape
        tokens = self.attention_norms[idx](feat.reshape(b, c, h * w)).transpose(1, 2)
        attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
        return feat + attn_out.transpose(1, 2).reshape(b, c, h, w)

    def _cross_attention(self, idx, feat, batch, context):
        # Cross-attention: spatial tokens query the projected context.
        assert context is not None, "context cannot be None if cross attention layers are used"
        b, c, h, w = feat.shape
        tokens = self.cross_attention_norms[idx](feat.reshape(b, c, h * w)).transpose(1, 2)
        assert context.shape[0] == batch and context.shape[-1] == self.context_dim
        ctx = self.context_proj[idx](context)
        attn_out, _ = self.cross_attentions[idx](tokens, ctx, ctx)
        return feat + attn_out.transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x, t_emb=None, context=None):
        out = x
        for i in range(self.num_layers):
            out = self._resnet(i, out, t_emb)
            if self.attn:
                out = self._self_attention(i, out)
            if self.cross_attn:
                out = self._cross_attention(i, out, x.shape[0], context)
        return self.down_sample_conv(out)
|
|
|
|
|
|
|
|
class MidBlock(nn.Module):
    r"""
    Mid conv block with attention.
    An initial ResNet sub-block, then ``num_layers`` rounds of
    (self-attention, optional cross-attention, ResNet) — i.e.
    ``num_layers + 1`` ResNet sub-blocks in total.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        def layer_in_ch(idx):
            # Only the very first sub-block sees the incoming channel count.
            return in_channels if idx == 0 else out_channels

        # num_layers + 1 ResNet sub-blocks (one before the attention loop).
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, layer_in_ch(i)),
                nn.SiLU(),
                nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers + 1)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers + 1)
            ])
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers + 1)
        ])

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=1)
            for i in range(num_layers + 1)
        ])

    def _resnet(self, idx, feat, t_emb):
        # Conv stack + optional time conditioning + conv stack + 1x1 skip.
        skip = feat
        feat = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            feat = feat + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        feat = self.resnet_conv_second[idx](feat)
        return feat + self.residual_input_conv[idx](skip)

    def _self_attention(self, idx, feat):
        # Self-attention over flattened spatial positions (tokens = H*W).
        b, c, h, w = feat.shape
        tokens = self.attention_norms[idx](feat.reshape(b, c, h * w)).transpose(1, 2)
        attn_out, _ = self.attentions[idx](tokens, tokens, tokens)
        return feat + attn_out.transpose(1, 2).reshape(b, c, h, w)

    def _cross_attention(self, idx, feat, batch, context):
        # Cross-attention: spatial tokens query the projected context.
        assert context is not None, "context cannot be None if cross attention layers are used"
        b, c, h, w = feat.shape
        tokens = self.cross_attention_norms[idx](feat.reshape(b, c, h * w)).transpose(1, 2)
        assert context.shape[0] == batch and context.shape[-1] == self.context_dim
        ctx = self.context_proj[idx](context)
        attn_out, _ = self.cross_attentions[idx](tokens, ctx, ctx)
        return feat + attn_out.transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x, t_emb=None, context=None):
        out = self._resnet(0, x, t_emb)
        for i in range(self.num_layers):
            out = self._self_attention(i, out)
            if self.cross_attn:
                out = self._cross_attention(i, out, x.shape[0], context)
            out = self._resnet(i + 1, out, t_emb)
        return out
|
|
|
|
|
|
|
|
class UpBlock(nn.Module):
    r"""
    Up conv block with attention.
    Transpose-conv upsample, optional channel-concat with the matching
    down-block output, then ``num_layers`` of (ResNet with time embedding,
    optional self-attention).
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        def layer_in_ch(idx):
            # First layer sees the (possibly concatenated) input channels.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, layer_in_ch(i)),
                nn.SiLU(),
                nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        if self.attn:
            self.attention_norms = nn.ModuleList([
                nn.GroupNorm(norm_channels, out_channels)
                for _ in range(num_layers)
            ])
            self.attentions = nn.ModuleList([
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            # Channel-wise concat with the matching encoder feature map.
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # ResNet sub-block with optional time conditioning.
            skip = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](skip)

            if self.attn:
                # Self-attention over flattened spatial positions.
                b, c, h, w = out.shape
                tokens = self.attention_norms[i](out.reshape(b, c, h * w)).transpose(1, 2)
                attn_out, _ = self.attentions[i](tokens, tokens, tokens)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)
        return out
|
|
|
|
|
|
|
|
class UpBlockUnet(nn.Module):
    r"""
    UNet decoder block: transpose-conv upsample (applied to the pre-concat
    half of the channels), skip concat with the matching down-block output,
    then ``num_layers`` of (ResNet with time embedding, self-attention,
    optional cross-attention).
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
                 num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.cross_attn = cross_attn
        self.context_dim = context_dim

        def layer_in_ch(idx):
            # First layer sees the post-concat channel count.
            return in_channels if idx == 0 else out_channels

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, layer_in_ch(i)),
                nn.SiLU(),
                nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=3, stride=1, padding=1),
            )
            for i in range(num_layers)
        ])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(norm_channels, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in range(num_layers)
        ])

        # Self-attention is always on in this block (unlike UpBlock).
        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(norm_channels, out_channels)
            for _ in range(num_layers)
        ])
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        if self.cross_attn:
            assert context_dim is not None, "Context Dimension must be passed for cross attention"
            self.cross_attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)])
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)])
            self.context_proj = nn.ModuleList(
                [nn.Linear(context_dim, out_channels) for _ in range(num_layers)])

        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(layer_in_ch(i), out_channels, kernel_size=1)
            for i in range(num_layers)
        ])
        # Upsample runs before the skip concat, hence in_channels // 2.
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()

    def forward(self, x, out_down=None, t_emb=None, context=None):
        x = self.up_sample_conv(x)
        if out_down is not None:
            x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            # ResNet sub-block with optional time conditioning.
            skip = out
            out = self.resnet_conv_first[i](out)
            if self.t_emb_dim is not None:
                out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](skip)

            # Self-attention over flattened spatial positions.
            b, c, h, w = out.shape
            tokens = self.attention_norms[i](out.reshape(b, c, h * w)).transpose(1, 2)
            attn_out, _ = self.attentions[i](tokens, tokens, tokens)
            out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

            if self.cross_attn:
                assert context is not None, "context cannot be None if cross attention layers are used"
                b, c, h, w = out.shape
                tokens = self.cross_attention_norms[i](out.reshape(b, c, h * w)).transpose(1, 2)
                assert len(context.shape) == 3, \
                    "Context shape does not match B,_,CONTEXT_DIM"
                assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
                    "Context shape does not match B,_,CONTEXT_DIM"
                ctx = self.context_proj[i](context)
                attn_out, _ = self.cross_attentions[i](tokens, ctx, ctx)
                out = out + attn_out.transpose(1, 2).reshape(b, c, h, w)

        return out
|
|
|
|
|
"""Vqvae""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class VQVAE(nn.Module):
    """Vector-Quantized VAE: conv encoder -> nearest-codebook quantization
    (with straight-through gradients) -> conv decoder.

    :param im_channels: number of channels in the input image.
    :param model_config: config object exposing attributes down_channels,
        mid_channels, down_sample, num_down_layers, num_mid_layers,
        num_up_layers, attn_down, z_channels, codebook_size, norm_channels,
        num_heads.
        NOTE(review): accessed by attribute here, while VAE in this file
        reads its config as a dict — confirm which style callers provide.
    """

    def __init__(self, im_channels, model_config):
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers

        # Per-stage flags: which encoder stages use self-attention.
        self.attns = model_config.attn_down

        self.z_channels = model_config.z_channels
        self.codebook_size = model_config.codebook_size
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads

        # Sanity-check that the channel configuration is self-consistent.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Decoder mirrors the encoder's downsampling pattern.
        self.up_sample = list(reversed(self.down_sample))

        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        # Encoder: downsampling stages (no time conditioning: t_emb_dim=None).
        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))

        # Encoder: mid (bottleneck) blocks with self-attention.
        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)

        # 1x1 conv applied just before quantization.
        self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)

        # Codebook: codebook_size vectors of dimension z_channels.
        self.embedding = nn.Embedding(self.codebook_size, self.z_channels)

        # 1x1 conv applied to the quantized latents before decoding.
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))

        # Decoder: mid blocks in reverse channel order.
        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        # Decoder: upsampling stages mirroring the encoder.
        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
                                               t_emb_dim=None, up_sample=self.down_sample[i - 1],
                                               num_heads=self.num_heads,
                                               num_layers=self.num_up_layers,
                                               attn=self.attns[i-1],
                                               norm_channels=self.norm_channels))

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)

    def quantize(self, x):
        """Snap each spatial latent vector to its nearest codebook entry.

        :param x: latent tensor of shape (B, C, H, W) with C == z_channels.
        :return: (quantized latents (B, C, H, W),
                  dict with 'codebook_loss' and 'commitment_loss',
                  nearest-code indices of shape (B, H, W)).
        """
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, H, W, C): channel-last so each spatial
        # position is one C-dim vector.
        x = x.permute(0, 2, 3, 1)

        # Flatten spatial dims: (B, H*W, C).
        x = x.reshape(x.size(0), -1, x.size(-1))

        # Pairwise distances from each latent vector to every codebook
        # entry; codebook is broadcast to (B, codebook_size, C).
        dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))

        # Index of the nearest code per spatial position: (B, H*W).
        min_encoding_indices = torch.argmin(dist, dim=-1)

        # Gather the selected codebook vectors: (B*H*W, C).
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))

        # Flatten x the same way so the two loss terms line up elementwise.
        x = x.reshape((-1, x.size(-1)))
        commmitment_loss = torch.mean((quant_out.detach() - x) ** 2)
        codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
        quantize_losses = {
            'codebook_loss': codebook_loss,
            'commitment_loss': commmitment_loss
        }

        # Straight-through estimator: forward uses quant_out, gradients
        # flow through x unchanged.
        quant_out = x + (quant_out - x).detach()

        # Back to (B, C, H, W); indices reshaped to (B, H, W).
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
        return quant_out, quantize_losses, min_encoding_indices

    def encode(self, x):
        """Encode an image batch to quantized latents.

        :return: (quantized latents, quantization loss dict).
        """
        out = self.encoder_conv_in(x)
        for idx, down in enumerate(self.encoder_layers):
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        out, quant_losses, _ = self.quantize(out)
        return out, quant_losses

    def decode(self, z):
        """Decode quantized latents z back to image space."""
        out = z
        out = self.post_quant_conv(out)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for idx, up in enumerate(self.decoder_layers):
            out = up(out)

        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Full autoencode pass.

        :return: (reconstruction, quantized latents, quantization losses).
        """
        z, quant_losses = self.encode(x)
        out = self.decode(z)
        return out, z, quant_losses
|
|
|
|
|
"""Vae""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class VAE(nn.Module):
    r"""
    Convolutional variational autoencoder.

    The encoder downsamples the image to a ``2*z_channels`` feature map
    holding the posterior mean and log-variance; a latent is drawn with the
    reparameterization trick and the decoder mirrors the encoder to
    reconstruct the image.
    """

    def __init__(self, im_channels, model_config):
        """
        :param im_channels: number of channels of the input/output images
        :param model_config: dict with channel schedule, layer counts,
            attention flags, z_channels, norm_channels and num_heads
        """
        super().__init__()
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # Per-downsampling-stage attention flags
        self.attns = model_config['attn_down']

        self.z_channels = model_config['z_channels']
        self.norm_channels = model_config['norm_channels']
        self.num_heads = model_config['num_heads']

        # Mid blocks must start and end at the deepest encoder width,
        # and the per-stage lists must match the channel schedule
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Upsampling mirrors the downsampling schedule
        self.up_sample = list(reversed(self.down_sample))

        ##################### Encoder #####################
        self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        self.encoder_layers = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                                 t_emb_dim=None, down_sample=self.down_sample[i],
                                                 num_heads=self.num_heads,
                                                 num_layers=self.num_down_layers,
                                                 attn=self.attns[i],
                                                 norm_channels=self.norm_channels))

        self.encoder_mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
        # 2 * z_channels: concatenated mean and log-variance maps
        self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1)
        self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1)

        ##################### Decoder #####################
        self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
        self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))

        self.decoder_mids = nn.ModuleList([])
        for i in reversed(range(1, len(self.mid_channels))):
            self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
                                              t_emb_dim=None,
                                              num_heads=self.num_heads,
                                              num_layers=self.num_mid_layers,
                                              norm_channels=self.norm_channels))

        self.decoder_layers = nn.ModuleList([])
        for i in reversed(range(1, len(self.down_channels))):
            self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
                                               t_emb_dim=None, up_sample=self.down_sample[i - 1],
                                               num_heads=self.num_heads,
                                               num_layers=self.num_up_layers,
                                               attn=self.attns[i - 1],
                                               norm_channels=self.norm_channels))

        self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
        self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)

    def encode(self, x):
        """Encode ``x``; returns (reparameterized sample, raw mean/logvar map)."""
        out = self.encoder_conv_in(x)
        for idx, down in enumerate(self.encoder_layers):
            out = down(out)
        for mid in self.encoder_mids:
            out = mid(out)
        out = self.encoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.encoder_conv_out(out)
        out = self.pre_quant_conv(out)
        mean, logvar = torch.chunk(out, 2, dim=1)
        std = torch.exp(0.5 * logvar)
        # FIX: randn_like samples directly on mean's device AND dtype;
        # the previous torch.randn(mean.shape).to(device=x.device) allocated
        # on CPU first and used the default dtype regardless of the input's.
        sample = mean + std * torch.randn_like(mean)
        return sample, out

    def decode(self, z):
        """Reconstruct an image batch from a latent sample ``z``."""
        out = z
        out = self.post_quant_conv(out)
        out = self.decoder_conv_in(out)
        for mid in self.decoder_mids:
            out = mid(out)
        for idx, up in enumerate(self.decoder_layers):
            out = up(out)

        out = self.decoder_norm_out(out)
        out = nn.SiLU()(out)
        out = self.decoder_conv_out(out)
        return out

    def forward(self, x):
        """Full pass: returns (reconstruction, raw encoder mean/logvar output)."""
        z, encoder_output = self.encode(x)
        out = self.decode(z)
        return out, encoder_output
|
|
|
|
|
"""Discriminator""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    r"""
    PatchGAN Discriminator.

    Rather than mapping IMG_CHANNELSxIMG_HxIMG_W all the way down to one
    scalar, it predicts a grid of values; each grid cell estimates how likely
    the corresponding image patch is to be real.
    """

    def __init__(self, im_channels=3,
                 conv_channels=(64, 128, 256),
                 kernels=(4, 4, 4, 4),
                 strides=(2, 2, 2, 1),
                 paddings=(1, 1, 1, 1)):
        """
        :param im_channels: channels of the input images
        :param conv_channels: widths of the intermediate conv layers
        :param kernels/strides/paddings: per-layer conv hyper-parameters
            (one extra entry for the final 1-channel projection layer)

        Defaults are tuples rather than lists to avoid the mutable-default
        pitfall; callers may still pass lists.
        """
        super().__init__()
        self.im_channels = im_channels
        activation = nn.LeakyReLU(0.2)
        layers_dim = [self.im_channels] + list(conv_channels) + [1]
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(layers_dim[i], layers_dim[i + 1],
                          kernel_size=kernels[i],
                          stride=strides[i],
                          padding=paddings[i],
                          # bias is redundant when a BatchNorm follows (i != 0)
                          bias=False if i != 0 else True),
                # no norm on the first layer (pix2pix convention) or on the
                # final 1-channel prediction layer
                nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
                activation if i != len(layers_dim) - 2 else nn.Identity()
            )
            for i in range(len(layers_dim) - 1)
        ])

    def forward(self, x):
        """Return the patch-level real/fake prediction grid for ``x``."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
|
import os |
|
|
import torchvision |
|
|
from PIL import Image |
|
|
from tqdm import tqdm, trange |
|
|
|
|
|
from torch.utils.data.dataset import Dataset |
|
|
|
|
|
import pickle |
|
|
import glob |
|
|
import os |
|
|
import torch |
|
|
|
|
|
|
|
|
def load_latents(latent_path):
    r"""
    Simple utility to load pre-computed latents to speed up ldm training.

    :param latent_path: directory containing ``*.pkl`` files, each a dict
        mapping image path -> batched latent (the leading batch dim is dropped)
    :return: dict mapping image path -> latent
    """
    latent_maps = {}
    for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
        # FIX: use a context manager so the pickle file handle is closed
        # (previously pickle.load(open(...)) leaked one handle per file)
        with open(fname, 'rb') as f:
            s = pickle.load(f)
        for k, v in s.items():
            latent_maps[k] = v[0]
    return latent_maps
|
|
|
|
|
|
|
|
def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
    """Randomly replace per-sample text embeddings with the empty-text one.

    Each batch element is dropped with probability ``text_drop_prob``; dropped
    rows of ``text_embed`` are overwritten in place by ``empty_text_embed[0]``
    (classifier-free guidance training). Returns ``text_embed``.
    """
    if text_drop_prob > 0:
        rand_vals = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0,
                                                                                  1)
        drop_mask = rand_vals < text_drop_prob
        assert empty_text_embed is not None, ("Text Conditioning required as well as"
                                              " text dropping but empty text representation not created")
        text_embed[drop_mask, :, :] = empty_text_embed[0]
    return text_embed
|
|
|
|
|
|
|
|
def drop_image_condition(image_condition, im, im_drop_prob):
    """Randomly zero out the image conditioning (classifier-free guidance).

    Each batch element is kept with probability ``1 - im_drop_prob``; dropped
    elements have their conditioning multiplied by zero.
    """
    if im_drop_prob <= 0:
        return image_condition
    keep_mask = (torch.zeros((im.shape[0], 1, 1, 1), device=im.device)
                 .float().uniform_(0, 1) > im_drop_prob)
    return image_condition * keep_mask
|
|
|
|
|
|
|
|
def drop_class_condition(class_condition, class_drop_prob, im):
    """Randomly zero out the class conditioning (classifier-free guidance).

    Each batch element is kept with probability ``1 - class_drop_prob``;
    dropped elements have their conditioning multiplied by zero.
    """
    if class_drop_prob <= 0:
        return class_condition
    keep_mask = (torch.zeros((im.shape[0], 1), device=im.device)
                 .float().uniform_(0, 1) > class_drop_prob)
    return class_condition * keep_mask
|
|
|
|
|
|
|
|
class MnistDataset(Dataset):
    r"""
    Nothing special here. Just a simple dataset class for mnist images.
    Created a dataset class rather than using torchvision to allow
    replacement with any other image dataset.

    Images are expected in one sub-folder per class; the folder name is the
    integer class label.
    """

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None, condition_config=None):
        r"""
        Init method for initializing the dataset properties.

        :param split: train/test, used only for logging
        :param im_path: root folder of images (one sub-folder per class)
        :param im_size: target image size (kept for API parity; images are
            used as-is here)
        :param im_channels: number of image channels
        :param use_latents: serve pre-computed latents instead of images
        :param latent_path: directory of pre-computed latents
        :param condition_config: dict with 'condition_types' (e.g. ['class'])
        """
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels

        self.latent_maps = None
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.images, self.labels = self.load_images(im_path)

        # Only switch to latents when exactly one latent exists per image
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path):
        r"""
        Gets all images from the path specified and stacks them all up.

        :param im_path: root directory with one sub-folder per class label
        :return: (list of image paths, list of int labels)
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        for d_name in tqdm(os.listdir(im_path)):
            fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png')))
            fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg')))
            fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg')))
            for fname in fnames:
                ims.append(fname)
                if 'class' in self.condition_types:
                    # folder name doubles as the integer class label
                    labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        cond_inputs = {}
        if 'class' in self.condition_types:
            cond_inputs['class'] = self.labels[index]

        if self.use_latents:
            latent = self.latent_maps[self.images[index]]
            if len(self.condition_types) == 0:
                return latent
            else:
                return latent, cond_inputs
        else:
            im = Image.open(self.images[index])
            im_tensor = torchvision.transforms.ToTensor()(im)
            # FIX: close the PIL file handle (was leaked; the other dataset
            # classes in this file already call im.close())
            im.close()

            # scale from [0, 1] to [-1, 1]
            im_tensor = (2 * im_tensor) - 1
            if len(self.condition_types) == 0:
                return im_tensor
            else:
                return im_tensor, cond_inputs
|
|
|
|
|
|
|
|
class AnimeFaceDataset(Dataset):
    r"""
    Dataset of anime face images stored flat in a single directory.
    Optionally serves pre-computed latents instead of images.
    Conditioning is parsed for API parity but not used by this dataset.
    """

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None, condition_config=None):
        """
        :param split: train/test, used only for logging by sibling datasets
        :param im_path: directory containing the image files
        :param im_size: square size images are resized/cropped to
        :param im_channels: number of image channels
        :param use_latents: serve pre-computed latents instead of images
        :param latent_path: directory of pre-computed latents
        :param condition_config: dict with 'condition_types' (unused here)
        """
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels

        self.latent_maps = None
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.images = self.load_images(im_path)

        # Only switch to latents when exactly one latent exists per image
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path):
        r"""
        Gets all images from the path specified and stacks them all up.

        :param im_path: directory containing the image files
        :return: list of image file paths
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
        return ims

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            # FIX: previously fell through and implicitly returned None when
            # condition_types was non-empty; this dataset never builds
            # condition inputs, so always return the latent.
            return self.latent_maps[self.images[index]]

        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()

        # scale from [0, 1] to [-1, 1]
        im_tensor = (2 * im_tensor) - 1
        # FIX: same fall-through defect as above — always return the tensor.
        return im_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
|
import os |
|
|
import random |
|
|
import torch |
|
|
import torchvision |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from torch.utils.data.dataset import Dataset |
|
|
|
|
|
|
|
|
class CelebDataset(Dataset):
    r"""
    Dataset of CelebA images stored flat in a single directory.
    Optionally serves pre-computed latents instead of images.
    Conditioning is parsed for API parity but not used by this dataset.
    """

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None, condition_config=None):
        """
        :param split: train/test, used only for logging by sibling datasets
        :param im_path: directory containing the image files
        :param im_size: square size images are center-cropped to
        :param im_channels: number of image channels
        :param use_latents: serve pre-computed latents instead of images
        :param latent_path: directory of pre-computed latents
        :param condition_config: dict with 'condition_types' (unused here)
        """
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels

        self.latent_maps = None
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.images = self.load_images(im_path)

        # Only switch to latents when exactly one latent exists per image
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path):
        r"""
        Gets all images from the path specified and stacks them all up.

        :param im_path: directory containing the image files
        :return: list of image file paths
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
        return ims

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            # FIX: previously fell through and implicitly returned None when
            # condition_types was non-empty; this dataset never builds
            # condition inputs, so always return the latent.
            return self.latent_maps[self.images[index]]

        im = Image.open(self.images[index])
        # NOTE: no Resize here (unlike AnimeFaceDataset) — images are only
        # center-cropped, matching the original behavior.
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()

        # scale from [0, 1] to [-1, 1]
        im_tensor = (2 * im_tensor) - 1
        # FIX: same fall-through defect as above — always return the tensor.
        return im_tensor
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
class CelebHairDataset(Dataset):
    r"""
    CelebA subset whose image list comes from a hair-attribute csv.
    The csv must contain an ``image_id`` column with the image filenames.
    Optionally serves pre-computed latents instead of images.
    """

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None, condition_config=None,
                 csv_path="/home/taruntejaneurips23/Ashish/DDPM/hair_df_100.csv"):
        """
        :param split: train/test, used only for logging by sibling datasets
        :param im_path: directory containing the image files
        :param im_size: square size images are center-cropped to
        :param im_channels: number of image channels
        :param use_latents: serve pre-computed latents instead of images
        :param latent_path: directory of pre-computed latents
        :param condition_config: dict with 'condition_types' (unused here)
        :param csv_path: attribute csv with an ``image_id`` column
            (generalized from a hard-coded path; default keeps old behavior)
        """
        self.df = pd.read_csv(csv_path)
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels

        self.latent_maps = None
        self.use_latents = False

        self.condition_types = [] if condition_config is None else condition_config['condition_types']

        self.images = self.load_images(im_path, self.df)

        # Only switch to latents when exactly one latent exists per image
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path, df):
        r"""
        Builds the image path list from the csv's image ids.

        :param im_path: directory containing the image files
        :param df: DataFrame with an ``image_id`` column
        :return: list of image file paths
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = [os.path.join(im_path, i) for i in df.image_id.values]
        return ims

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            # FIX: previously fell through and implicitly returned None when
            # condition_types was non-empty; this dataset never builds
            # condition inputs, so always return the latent.
            return self.latent_maps[self.images[index]]

        im = Image.open(self.images[index])
        im_tensor = torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])(im)
        im.close()

        # scale from [0, 1] to [-1, 1]
        im_tensor = (2 * im_tensor) - 1
        # FIX: same fall-through defect as above — always return the tensor.
        return im_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import yaml |
|
|
from dotdict import DotDict |
|
|
|
|
|
# Path to the LDM-on-CelebA experiment configuration (hard-coded; edit as needed)
config_path = "/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml"
with open(config_path, 'r') as file:
    Config = yaml.safe_load(file)

# Wrap the plain dict so config sections can be accessed with attribute syntax
Config = DotDict.from_dict(Config)
# Convenience handles to the individual config sections
dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params
|
|
|
|
|
import torch |
|
|
import os |
|
|
import random |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from tqdm import tqdm |
|
|
from torch.optim import Adam |
|
|
from torch.utils.data import Dataset, TensorDataset, DataLoader |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
|
|
|
|
|
|
from torchvision.utils import make_grid |
|
|
|
|
|
def trainVAE(Config):
    """Train the VQVAE autoencoder with reconstruction + codebook/commitment
    losses, an LPIPS perceptual loss and (after ``disc_start`` steps) a
    PatchGAN adversarial loss. Saves sample reconstructions periodically and
    checkpoints model and discriminator each epoch.

    :param Config: DotDict with dataset_params, autoencoder_params and
        train_params sections.
    """
    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    # Seed everything for reproducibility
    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(seed)

    model = VQVAE(im_channels=dataset_config.im_channels,
                  model_config=autoencoder_config).to(device)

    # Resume from an existing autoencoder checkpoint if present
    if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)):
        print('Loaded vae checkpoint')
        model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),
                                         map_location=device, weights_only=True))

    # Pick the dataset class by name from the config
    im_dataset_cls = {
        'mnist': MnistDataset,
        'celebA': CelebDataset,
        'animeface': AnimeFaceDataset,
        'celebAhair': CelebHairDataset
    }.get(dataset_config.name)

    im_dataset = im_dataset_cls(split='train',
                                im_path=dataset_config.im_path,
                                im_size=dataset_config.im_size,
                                im_channels=dataset_config.im_channels)

    data_loader = DataLoader(im_dataset,
                             batch_size=train_config.autoencoder_batch_size,
                             shuffle=True,
                             num_workers=os.cpu_count(),
                             pin_memory=True,
                             drop_last=True,
                             persistent_workers=True, pin_memory_device=device)

    # Output directory for checkpoints and samples
    if not os.path.exists(train_config.task_name):
        os.mkdir(train_config.task_name)

    num_epochs = train_config.autoencoder_epochs

    # L2 reconstruction loss; LSGAN-style MSE for the discriminator
    recon_criterion = torch.nn.MSELoss()
    disc_criterion = torch.nn.MSELoss()

    # Frozen perceptual-loss network and the PatchGAN discriminator
    lpips_model = LPIPS().eval().to(device)
    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)

    if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)):
        print('Loaded discriminator checkpoint')
        discriminator.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),
                                                 map_location=device, weights_only=True))

    # betas=(0.5, 0.999) is the usual GAN setting
    optimizer_d = Adam(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = Adam(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    # Adversarial training only kicks in after this many steps
    disc_step_start = train_config.disc_start
    step_count = 0

    # Gradient accumulation and periodic sample-image saving
    acc_steps = train_config.autoencoder_acc_steps
    image_save_steps = train_config.autoencoder_img_save_steps
    img_save_count = 0

    for epoch_idx in trange(num_epochs, desc='Training VQVAE'):
        recon_losses = []
        codebook_losses = []
        perceptual_losses = []
        disc_losses = []
        gen_losses = []
        losses = []

        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        for im in data_loader:
            step_count += 1
            im = im.float().to(device)

            model_output = model(im)
            output, z, quantize_losses = model_output

            # Periodically save an input/reconstruction comparison grid
            if step_count % image_save_steps == 0 or step_count == 1:
                sample_size = min(8, im.shape[0])
                save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
                save_output = ((save_output + 1) / 2)
                save_input = ((im[:sample_size] + 1) / 2).detach().cpu()

                grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
                img = torchvision.transforms.ToPILImage()(grid)
                if not os.path.exists(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples')):
                    os.mkdir(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples'))
                img.save(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples',
                                      'current_autoencoder_sample_{}.png'.format(img_save_count)))
                img_save_count += 1
                img.close()

            ######### Generator (autoencoder) loss #########
            recon_loss = recon_criterion(output, im)
            recon_losses.append(recon_loss.item())
            recon_loss = recon_loss / acc_steps
            g_loss = (recon_loss +
                      (train_config.codebook_weight * quantize_losses['codebook_loss'] / acc_steps) +
                      (train_config.commitment_beta * quantize_losses['commitment_loss'] / acc_steps))
            codebook_losses.append(train_config.codebook_weight * quantize_losses['codebook_loss'].item())

            # Adversarial term: fool the discriminator into predicting "real"
            if step_count > disc_step_start:
                disc_fake_pred = discriminator(model_output[0])
                disc_fake_loss = disc_criterion(disc_fake_pred,
                                                torch.ones(disc_fake_pred.shape,
                                                           device=disc_fake_pred.device))
                gen_losses.append(train_config.disc_weight * disc_fake_loss.item())
                g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
            # Perceptual (LPIPS) term
            lpips_loss = torch.mean(lpips_model(output, im)) / acc_steps
            perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())
            g_loss += train_config.perceptual_weight * lpips_loss / acc_steps
            losses.append(g_loss.item())
            g_loss.backward()

            ######### Discriminator loss #########
            if step_count > disc_step_start:
                fake = output
                # detach so only the discriminator is updated here
                disc_fake_pred = discriminator(fake.detach())
                disc_real_pred = discriminator(im)
                disc_fake_loss = disc_criterion(disc_fake_pred,
                                                torch.zeros(disc_fake_pred.shape,
                                                            device=disc_fake_pred.device))
                disc_real_loss = disc_criterion(disc_real_pred,
                                                torch.ones(disc_real_pred.shape,
                                                           device=disc_real_pred.device))
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
                disc_losses.append(disc_loss.item())
                disc_loss = disc_loss / acc_steps
                disc_loss.backward()
                if step_count % acc_steps == 0:
                    optimizer_d.step()
                    optimizer_d.zero_grad()

            # Step the generator every acc_steps batches (gradient accumulation)
            if step_count % acc_steps == 0:
                optimizer_g.step()
                optimizer_g.zero_grad()
        # End of epoch: flush any gradients still accumulated mid-cycle
        optimizer_d.step()
        optimizer_d.zero_grad()
        optimizer_g.step()
        optimizer_g.zero_grad()
        # Log epoch statistics (adversarial stats only once disc training began)
        if len(disc_losses) > 0:
            print(
                'Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
                'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
                format(epoch_idx + 1,
                       num_epochs,
                       np.mean(recon_losses),
                       np.mean(perceptual_losses),
                       np.mean(codebook_losses),
                       np.mean(gen_losses),
                       np.mean(disc_losses)))
        else:
            print('Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'.
                  format(epoch_idx + 1,
                         num_epochs,
                         np.mean(recon_losses),
                         np.mean(perceptual_losses),
                         np.mean(codebook_losses)))

        # Checkpoint both networks every epoch
        torch.save(model.state_dict(), os.path.join(train_config.task_name,
                                                    train_config.vqvae_autoencoder_ckpt_name))
        torch.save(discriminator.state_dict(), os.path.join(train_config.task_name,
                                                            train_config.vqvae_discriminator_ckpt_name))
    print('Done Training...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class Unet(nn.Module):
    r"""
    Unet model comprising Down blocks, Mid blocks and Up blocks,
    conditioned on a sinusoidal time embedding (for diffusion training).
    """

    def __init__(self, im_channels, model_config):
        """
        :param im_channels: channels of the (latent) input the Unet denoises
        :param model_config: DotDict with channel schedule, layer counts,
            attention flags, time_emb_dim, norm_channels, num_heads and
            conv_out_channels
        """
        super().__init__()
        self.down_channels = model_config.down_channels
        self.mid_channels = model_config.mid_channels
        self.t_emb_dim = model_config.time_emb_dim
        self.down_sample = model_config.down_sample
        self.num_down_layers = model_config.num_down_layers
        self.num_mid_layers = model_config.num_mid_layers
        self.num_up_layers = model_config.num_up_layers
        self.attns = model_config.attn_down
        self.norm_channels = model_config.norm_channels
        self.num_heads = model_config.num_heads
        self.conv_out_channels = model_config.conv_out_channels

        # Mid blocks bridge the deepest encoder width back toward the decoder
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # MLP projecting the sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
                                        down_sample=self.down_sample[i],
                                        num_heads=self.num_heads,
                                        num_layers=self.num_down_layers,
                                        attn=self.attns[i], norm_channels=self.norm_channels))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                                      num_heads=self.num_heads,
                                      num_layers=self.num_mid_layers,
                                      norm_channels=self.norm_channels))

        # Up blocks take 2x channels because skip connections are concatenated
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                                        self.t_emb_dim, up_sample=self.down_sample[i],
                                        num_heads=self.num_heads,
                                        num_layers=self.num_up_layers,
                                        norm_channels=self.norm_channels))

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at diffusion timestep(s) ``t``.

        :param x: noisy (latent) batch of shape (B, im_channels, H, W)
        :param t: integer timestep(s), broadcastable to batch size
        :return: noise prediction of the same shape as ``x``
        """
        out = self.conv_in(x)

        # Build and project the sinusoidal time embedding
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # Encoder path; stash activations for the skip connections
        down_outs = []
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Decoder path; pop skip activations in reverse order
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)

        return out
|
|
|
|
|
def trainLDM(Config):
    """Train the latent diffusion Unet on VQVAE latents.

    If pre-computed latents are not available the VQVAE is loaded (frozen)
    and images are encoded on the fly. After each epoch the model is saved,
    ``infer`` is run, and the yaml config is re-read so training can be
    stopped externally by setting ``training._continue_`` to false.

    :param Config: DotDict with diffusion/dataset/ldm/autoencoder/train sections.
    """
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params

    # Forward-process noise schedule
    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
                                     beta_start=diffusion_config.beta_start,
                                     beta_end=diffusion_config.beta_end)

    # Pick the dataset class by name from the config
    im_dataset_cls = {
        'mnist': MnistDataset,
        'celebA': CelebDataset,
        'animeface': AnimeFaceDataset,
        'celebAhair': CelebHairDataset
    }.get(dataset_config.name)

    im_dataset = im_dataset_cls(split='train',
                                im_path=dataset_config.im_path,
                                im_size=dataset_config.im_size,
                                im_channels=dataset_config.im_channels,
                                use_latents=True,
                                latent_path=os.path.join(train_config.task_name,
                                                         train_config.vqvae_latent_dir_name)
                                )

    data_loader = DataLoader(im_dataset,
                             batch_size=train_config.ldm_batch_size,
                             shuffle=True,
                             num_workers=os.cpu_count(),
                             pin_memory=True,
                             drop_last=False,
                             persistent_workers=True, pin_memory_device=device)

    # Diffusion Unet operates in the autoencoder's latent space
    model = Unet(im_channels=autoencoder_model_config.z_channels,
                 model_config=diffusion_model_config).to(device)
    if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)):
        print('Loaded ldm checkpoint')
        model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.ldm_ckpt_name), map_location=device, weights_only=True))
    model.train()

    # Fall back to on-the-fly encoding when latents were not pre-computed
    if not im_dataset.use_latents:
        print('Loading vqvae model as latents not present')
        vae = VQVAE(im_channels=dataset_config.im_channels,
                    model_config=autoencoder_model_config).to(device)
        vae.eval()

        if os.path.exists(os.path.join(train_config.task_name,
                                       train_config.vqvae_autoencoder_ckpt_name)):
            print('Loaded vae checkpoint')
            vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
                                                        train_config.vqvae_autoencoder_ckpt_name),
                                           map_location=device))

    num_epochs = train_config.ldm_epochs
    optimizer = Adam(model.parameters(), lr=train_config.ldm_lr)
    criterion = torch.nn.MSELoss()

    # VAE is only used as a frozen encoder
    if not im_dataset.use_latents:
        for param in vae.parameters():
            param.requires_grad = False

    for epoch_idx in range(num_epochs):
        losses = []
        for im in tqdm(data_loader):
            optimizer.zero_grad()
            im = im.float().to(device)
            if not im_dataset.use_latents:
                with torch.no_grad():
                    im, _ = vae.encode(im)

            # Sample noise and a random timestep per batch element
            noise = torch.randn_like(im).to(device)
            t = torch.randint(0, diffusion_config.num_timesteps, (im.shape[0],)).to(device)

            # Diffuse the latents and train the Unet to predict the noise
            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')

        # Checkpoint every epoch
        torch.save(model.state_dict(), os.path.join(train_config.task_name,
                                                    train_config.ldm_ckpt_name))

        # Generate samples with the current weights
        infer(Config)

        # Re-read the config so training can be stopped externally
        train_continue = yaml.safe_load(open("/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml", 'r'))
        train_continue = DotDict.from_dict(train_continue)
        if train_continue.training._continue_ == False:
            print('Training Stoped ...')
            break

    print('Done Training ...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample(model, scheduler, train_config, diffusion_model_config,
           autoencoder_model_config, diffusion_config, dataset_config, vae):
    r"""
    Sample stepwise by going backward one timestep at a time.

    At each reverse step the current state is saved as an image grid:
    intermediate steps save the raw latents, the final step (t == 0)
    saves the VAE-decoded image.

    :param model: latent-space diffusion UNet (noise predictor).
    :param scheduler: noise scheduler providing ``sample_prev_timestep``.
    :param train_config: config with num_samples / num_grid_rows / task_name.
    :param diffusion_model_config: unused here, kept for a uniform signature.
    :param autoencoder_model_config: config with z_channels / down_sample.
    :param diffusion_config: config with num_timesteps.
    :param dataset_config: config with im_size.
    :param vae: trained VQVAE used to decode the final latents.
    """
    # Latent spatial size: image size halved once per downsampling stage.
    im_size = dataset_config.im_size // 2 ** sum(autoencoder_model_config.down_sample)

    # Start from pure Gaussian noise in latent space.
    xt = torch.randn((train_config.num_samples,
                      autoencoder_model_config.z_channels,
                      im_size,
                      im_size)).to(device)

    for i in tqdm(reversed(range(diffusion_config.num_timesteps)),
                  total=diffusion_config.num_timesteps):
        # Predict the noise residual at the current timestep.
        noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))

        # One reverse-diffusion step: x_t -> x_{t-1} (x0 prediction unused here).
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        if i == 0:
            # Final step: decode latents back to image space with the VAE.
            ims = vae.decode(xt)
        else:
            # Intermediate step: visualize the raw latents.
            ims = xt

        ims = torch.clamp(ims, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2  # map [-1, 1] -> [0, 1] for saving
        grid = make_grid(ims, nrow=train_config.num_grid_rows)
        img = torchvision.transforms.ToPILImage()(grid)

        # makedirs with exist_ok avoids the check-then-create race of
        # the previous os.path.exists + os.mkdir pair.
        os.makedirs(os.path.join(train_config.task_name, 'samples'), exist_ok=True)
        img.save(os.path.join(train_config.task_name, 'samples', 'x0_{}.png'.format(i)))
        img.close()
|
|
|
|
|
|
|
|
def infer(Config):
    """Build scheduler, UNet and VQVAE from Config, restore their
    checkpoints when present, and generate samples without gradients."""
    diffusion_config = Config.diffusion_params
    dataset_config = Config.dataset_params
    diffusion_model_config = Config.ldm_params
    autoencoder_model_config = Config.autoencoder_params
    train_config = Config.train_params

    # The sampling schedule must match the one used during training.
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_config.num_timesteps,
        beta_start=diffusion_config.beta_start,
        beta_end=diffusion_config.beta_end,
    )

    unet_ckpt_path = os.path.join(train_config.task_name, train_config.ldm_ckpt_name)
    vae_ckpt_path = os.path.join(train_config.task_name,
                                 train_config.vqvae_autoencoder_ckpt_name)

    # Diffusion UNet operating in the VQVAE latent space.
    model = Unet(
        im_channels=autoencoder_model_config.z_channels,
        model_config=diffusion_model_config,
    ).to(device)
    model.eval()
    if os.path.exists(unet_ckpt_path):
        print('Loaded unet checkpoint')
        model.load_state_dict(torch.load(unet_ckpt_path, map_location=device))

    if not os.path.exists(train_config.task_name):
        os.mkdir(train_config.task_name)

    # VQVAE decodes sampled latents back into images.
    vae = VQVAE(
        im_channels=dataset_config.im_channels,
        model_config=autoencoder_model_config,
    ).to(device)
    vae.eval()
    if os.path.exists(vae_ckpt_path):
        print('Loaded vae checkpoint')
        vae.load_state_dict(torch.load(vae_ckpt_path, map_location=device),
                            strict=True)

    with torch.no_grad():
        sample(model, scheduler, train_config, diffusion_model_config,
               autoencoder_model_config, diffusion_config, dataset_config, vae)
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
|
|
def get_args():
    """Parse command-line arguments; returns a namespace with ``mode``."""
    arg_parser = argparse.ArgumentParser(
        description="Choose between train VAE, train LDM, or infer mode."
    )
    arg_parser.add_argument(
        '--mode',
        default='infer',
        choices=['train_vae', 'train_ldm', 'infer'],
        help="Mode to run: train_vae, train_ldm, or infer",
    )
    return arg_parser.parse_args()
|
|
|
|
|
if __name__ == '__main__':
    # Guard the entry point so importing this module does not immediately
    # parse argv and launch training/inference.
    args = get_args()

    if args.mode == 'train_vae':
        # NOTE(review): trainVAE is defined as trainVAE(Config, dataloader)
        # below — this call omits the dataloader argument; confirm the
        # intended entry point before relying on this mode.
        trainVAE(Config)
    elif args.mode == 'train_ldm':
        trainLDM(Config)
    else:
        infer(Config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(
    total_steps, epoch, model, discriminator,
    optimizer_d, optimizer_g, loss, checkpoint_path
):
    """Serialize the full training state (generator, discriminator, both
    optimizers, step/epoch counters and the loss record) to checkpoint_path."""
    state = dict(
        total_steps=total_steps,
        epoch=epoch,
        model_state_dict=model.state_dict(),
        discriminator_state_dict=discriminator.state_dict(),
        optimizer_d_state_dict=optimizer_d.state_dict(),
        optimizer_g_state_dict=optimizer_g.state_dict(),
        loss=loss,
    )
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")
|
|
|
|
|
|
|
|
def load_checkpoint(
    checkpoint_path, model, discriminator, optimizer_d, optimizer_g,
    map_location=None
):
    """
    Restore training state previously written by ``save_checkpoint``.

    Bug fix: a caller in this file invokes this with ``None`` optimizers
    (``load_checkpoint(path, model, discriminator, None, None)``), which
    crashed on ``optimizer_d.load_state_dict`` whenever a checkpoint
    existed — optimizer states are now skipped when the optimizer is None.

    :param checkpoint_path: path of the checkpoint file.
    :param model: generator whose weights are restored in place.
    :param discriminator: discriminator restored in place.
    :param optimizer_d: discriminator optimizer, or None to skip.
    :param optimizer_g: generator optimizer, or None to skip.
    :param map_location: optional device remap forwarded to ``torch.load``
        (default None preserves the previous behavior).
    :returns: ``(total_steps, start_epoch, loss)``; ``(0, 0, None)`` when
        no checkpoint exists.
    """
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting from scratch.")
        return 0, 0, None

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
    if optimizer_d is not None:
        optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
    if optimizer_g is not None:
        optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
    total_steps = checkpoint["total_steps"]
    # Resume from the epoch after the one that was saved.
    start_epoch = checkpoint["epoch"] + 1
    loss = checkpoint["loss"]
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
    return total_steps, start_epoch, loss
|
|
|
|
|
|
|
|
def trainVAE(Config, dataloader):
    """
    Trains a VQVAE with an adversarial discriminator and LPIPS perceptual
    loss using the provided configuration and data loader.

    Fixes in this revision:
      * ``device == "cuda"`` compared a torch.device to a str (so CUDA
        seeding could be skipped) -> compare ``device.type``.
      * codebook / perceptual losses were computed but never appended,
        so the epoch summary printed NaN -> now recorded.
      * the checkpoint was loaded before the optimizers existed (forcing
        ``None`` optimizers into load_checkpoint) -> loaded after creating
        them, restoring optimizer state on resume.
      * ``np.mean([])`` warned and printed NaN before the discriminator
        warm-up ended -> empty lists now report 0.0.
    """
    dataset_config = Config.dataset_params
    autoencoder_config = Config.autoencoder_params
    train_config = Config.train_params

    # Reproducibility.
    seed = train_config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device.type == "cuda":  # device is a torch.device, not a string
        torch.cuda.manual_seed_all(seed)

    model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
    discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)

    # Losses: L2 reconstruction, LPIPS perceptual, LSGAN-style MSE for D.
    recon_criterion = torch.nn.MSELoss()
    lpips_model = LPIPS().eval().to(device)
    disc_criterion = torch.nn.MSELoss()

    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
    optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))

    # Resume AFTER the optimizers exist so their state is restored too
    # (previously None was passed, which crashed inside load_checkpoint).
    checkpoint_path = os.path.join(train_config.task_name, "vqvae_checkpoint.pth")
    total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g)

    num_epochs = train_config.autoencoder_epochs
    acc_steps = train_config.autoencoder_acc_steps
    image_save_steps = train_config.autoencoder_img_save_steps
    img_save_count = 0

    os.makedirs(os.path.join(train_config.task_name, "vqvae_autoencoder_samples"), exist_ok=True)

    def _mean(values):
        # np.mean([]) returns NaN with a RuntimeWarning; report 0.0 instead
        # (e.g. gen/disc losses before the discriminator warm-up ends).
        return np.mean(values) if values else 0.0

    for epoch_idx in range(start_epoch, num_epochs):
        recon_losses, codebook_losses, perceptual_losses, disc_losses, gen_losses = [], [], [], [], []

        for images in dataloader:
            total_steps += 1
            images = images.to(device)

            output, z, quantize_losses = model(images)

            # Periodically dump an input/reconstruction grid for inspection.
            if total_steps % image_save_steps == 0 or total_steps == 1:
                sample_size = min(8, images.shape[0])
                save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
                save_output = (save_output + 1) / 2  # [-1, 1] -> [0, 1]
                save_input = ((images[:sample_size] + 1) / 2).detach().cpu()

                grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
                img = tv.transforms.ToPILImage()(grid)
                img.save(
                    os.path.join(
                        train_config.task_name,
                        "vqvae_autoencoder_samples",
                        f"current_autoencoder_sample_{img_save_count}.png",
                    )
                )
                img_save_count += 1
                img.close()

            # Generator losses (scaled for gradient accumulation).
            recon_loss = recon_criterion(output, images) / acc_steps
            recon_losses.append(recon_loss.item())

            codebook_loss = train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps
            codebook_losses.append(codebook_loss.item())  # was computed but never recorded
            perceptual_loss = train_config.perceptual_weight * lpips_model(output, images).mean() / acc_steps
            perceptual_losses.append(perceptual_loss.item())  # was computed but never recorded
            g_loss = recon_loss + codebook_loss + perceptual_loss

            # Adversarial term only after the discriminator warm-up period.
            if total_steps > train_config.disc_start:
                disc_fake_pred = discriminator(output)
                gen_loss = train_config.disc_weight * disc_criterion(
                    disc_fake_pred, torch.ones_like(disc_fake_pred)
                ) / acc_steps
                g_loss += gen_loss
                gen_losses.append(gen_loss.item())

            g_loss.backward()
            optimizer_g.step()
            optimizer_g.zero_grad()

            # Discriminator update: real images -> 1, reconstructions -> 0.
            if total_steps > train_config.disc_start:
                disc_fake_pred = discriminator(output.detach())
                disc_real_pred = discriminator(images)
                disc_fake_loss = disc_criterion(
                    disc_fake_pred, torch.zeros_like(disc_fake_pred)
                ) / acc_steps
                disc_real_loss = disc_criterion(
                    disc_real_pred, torch.ones_like(disc_real_pred)
                ) / acc_steps
                disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
                disc_loss.backward()
                optimizer_d.step()
                optimizer_d.zero_grad()
                disc_losses.append(disc_loss.item())

        save_checkpoint(total_steps, epoch_idx, model, discriminator, optimizer_d, optimizer_g, recon_losses, checkpoint_path)

        print(
            f"Epoch {epoch_idx + 1}/{num_epochs} | Recon Loss: {_mean(recon_losses):.4f} | "
            f"Perceptual Loss: {_mean(perceptual_losses):.4f} | Codebook Loss: {_mean(codebook_losses):.4f} | "
            f"G Loss: {_mean(gen_losses):.4f} | D Loss: {_mean(disc_losses):.4f}"
        )
|
|
|