# ==================================================================
# L A T E N T D I F F U S I O N M O D E L
# ==================================================================
# Author : Ashish Kumar Uchadiya
# Created : November 3, 2024
# Description: This script implements a Latent Diffusion Model using
# a cosine or linear noise scheduling approach for high-resolution
# image generation. The model leverages generative techniques to
# learn a latent representation and progressively reduce noise to
# generate clear, realistic images.
# ==================================================================
# I M P O R T S
# ==================================================================
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
"""Lpips"""
# from __future__ import absolute_import
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
import torch.nn
import torchvision
# Taken from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2, 3], keepdim=keepdim)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = torchvision.models.vgg16(
weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
# Freeze vgg model
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
# Return output of vgg features
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
# Learned perceptual metric
class LPIPS(nn.Module):
def __init__(self, net='vgg', version='0.1', use_dropout=True):
super(LPIPS, self).__init__()
self.version = version
# Imagenet normalization
self.scaling_layer = ScalingLayer()
########################
# Instantiate vgg model
self.chns = [64, 128, 256, 512, 512]
self.L = len(self.chns)
self.net = vgg16(pretrained=True, requires_grad=False)
# Add 1x1 convolutional Layers
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
self.lins = nn.ModuleList(self.lins)
########################
# Load the weights of trained LPIPS model
import inspect
import os
# /home/taruntejaneurips23/.cache/torch/hub/checkpoints/vgg16-397923af.pth
print(os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net))))
# model_path = os.path.abspath(
# os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
# print('Loading model from: %s' % model_path)
# self.load_state_dict(torch.load(model_path, map_location=device), strict=False)
########################
# Freeze all parameters
self.eval()
for param in self.parameters():
param.requires_grad = False
########################
def forward(self, in0, in1, normalize=False):
# Scale the inputs to -1 to +1 range if needed
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
########################
# Normalize the inputs according to imagenet normalization
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
########################
# Get VGG outputs for image0 and image1
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
########################
# Compute Square of Difference for each layer output
for kk in range(self.L):
            feats0[kk] = torch.nn.functional.normalize(outs0[kk], dim=1)
            feats1[kk] = torch.nn.functional.normalize(outs1[kk], dim=1)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
########################
# 1x1 convolution followed by spatial average on the square differences
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
val = 0
# Aggregate the results of each layer
for l in range(self.L):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
        # ImageNet normalization for inputs in [0, 1]
# mean = [0.485, 0.456, 0.406]
# std = [0.229, 0.224, 0.225]
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(), ] if (use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
self.model = nn.Sequential(*layers)
def forward(self, x):
out = self.model(x)
return out
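# A minimal usage sketch for the LPIPS metric above, assuming inputs in [0, 1]
# (so normalize=True rescales them to [-1, 1]). Defined only and never called,
# so running this script is unaffected.
def _lpips_demo():
    lpips_metric = LPIPS().to(device)
    img_a = torch.rand(4, 3, 64, 64, device=device)
    img_b = torch.rand(4, 3, 64, 64, device=device)
    with torch.no_grad():
        dist = lpips_metric(img_a, img_b, normalize=True)
    # One perceptual distance per image, shape (4, 1, 1, 1)
    print(dist.shape)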
"""Blocks"""
import torch
import numpy as np
class LinearNoiseScheduler:
r"""
Class for the linear noise scheduler that is used in DDPM.
"""
def __init__(self, num_timesteps, beta_start, beta_end):
self.num_timesteps = num_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
# Mimicking how compvis repo creates schedule
self.betas = (
torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
)
self.alphas = 1. - self.betas
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
def add_noise(self, original, noise, t):
r"""
Forward method for diffusion
:param original: Image on which noise is to be applied
:param noise: Random Noise Tensor (from normal dist)
:param t: timestep of the forward process of shape -> (B,)
:return:
"""
original_shape = original.shape
batch_size = original_shape[0]
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
for _ in range(len(original_shape) - 1):
sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
for _ in range(len(original_shape) - 1):
sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
# Apply and Return Forward process equation
return (sqrt_alpha_cum_prod.to(original.device) * original
+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
def sample_prev_timestep(self, xt, noise_pred, t):
r"""
        Use the noise prediction by the model to get
        x_{t-1} using x_t and the noise predicted
:param xt: current timestep sample
:param noise_pred: model noise prediction
:param t: current timestep we are at
:return:
"""
x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
x0 = torch.clamp(x0, -1., 1.)
mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
if t == 0:
return mean, x0
else:
variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
variance = variance * self.betas.to(xt.device)[t]
sigma = variance ** 0.5
z = torch.randn(xt.shape).to(xt.device)
# OR
# variance = self.betas[t]
# sigma = variance ** 0.5
# z = torch.randn(xt.shape).to(xt.device)
return mean + sigma * z, x0
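# A sanity-check sketch for the linear scheduler (beta values here are
# illustrative; the real range comes from the YAML config). The forward process
# is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so small t
# stays close to x_0 and large t is mostly noise. Defined only, never called.
def _linear_scheduler_demo():
    scheduler = LinearNoiseScheduler(num_timesteps=1000, beta_start=0.0015, beta_end=0.0195)
    x0 = torch.randn(2, 4, 16, 16)
    eps = torch.randn_like(x0)
    t = torch.tensor([0, 999])
    xt = scheduler.add_noise(x0, eps, t)
    print(xt.shape)  # torch.Size([2, 4, 16, 16])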
import torch
import math
class CosineNoiseScheduler:
r"""
Class for the cosine noise scheduler, often used in DDPM-based models.
"""
def __init__(self, num_timesteps, s=0.008):
self.num_timesteps = num_timesteps
self.s = s
# Cosine schedule based on paper
def cosine_schedule(t):
return math.cos((t / self.num_timesteps + s) / (1 + s) * math.pi / 2) ** 2
# Compute alphas
self.alphas = torch.tensor([cosine_schedule(t) for t in range(num_timesteps)])
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
def add_noise(self, original, noise, t):
original_shape = original.shape
batch_size = original_shape[0]
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size)
for _ in range(len(original_shape) - 1):
sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
for _ in range(len(original_shape) - 1):
sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
return (sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise)
def sample_prev_timestep(self, xt, noise_pred, t):
x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) /
torch.sqrt(self.alpha_cum_prod.to(xt.device)[t]))
x0 = torch.clamp(x0, -1., 1.)
mean = xt - ((1 - self.alphas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
if t == 0:
return mean, x0
else:
variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t])
variance = variance * (1 - self.alphas.to(xt.device)[t])
sigma = variance ** 0.5
z = torch.randn(xt.shape).to(xt.device)
return mean + sigma * z, x0
import torch
import torch.nn as nn
def get_time_embedding(time_steps, temb_dim):
r"""
Convert time steps tensor into an embedding using the
sinusoidal time embedding formula
:param time_steps: 1D tensor of length batch size
:param temb_dim: Dimension of the embedding
:return: BxD embedding representation of B time steps
"""
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
# factor = 10000^(2i/d_model)
    factor = 10000 ** (
        torch.arange(start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device)
        / (temb_dim // 2)
    )
# pos / factor
# timesteps B -> B, 1 -> B, temb_dim
t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
return t_emb
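# Shape sketch for the sinusoidal embedding above: B integer timesteps map to a
# (B, temb_dim) tensor whose first half holds sine terms and second half cosine
# terms at geometrically spaced frequencies. Defined only, never called.
def _time_embedding_demo():
    t = torch.tensor([0, 10, 500])
    emb = get_time_embedding(t, temb_dim=128)
    print(emb.shape)  # torch.Size([3, 128])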
class DownBlock(nn.Module):
r"""
Down conv block with attention.
Sequence of following block
1. Resnet block with time embedding
2. Attention block
3. Downsample
"""
def __init__(self, in_channels, out_channels, t_emb_dim,
down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
super().__init__()
self.num_layers = num_layers
self.down_sample = down_sample
self.attn = attn
self.context_dim = context_dim
self.cross_attn = cross_attn
self.t_emb_dim = t_emb_dim
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for i in range(num_layers)
]
)
if self.t_emb_dim is not None:
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(self.t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels,
kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
if self.attn:
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
if self.cross_attn:
assert context_dim is not None, "Context Dimension must be passed for cross attention"
self.cross_attention_norms = nn.ModuleList(
[nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)]
)
self.cross_attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.context_proj = nn.ModuleList(
[nn.Linear(context_dim, out_channels)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.down_sample_conv = nn.Conv2d(out_channels, out_channels,
4, 2, 1) if self.down_sample else nn.Identity()
def forward(self, x, t_emb=None, context=None):
out = x
for i in range(self.num_layers):
# Resnet block of Unet
resnet_input = out
out = self.resnet_conv_first[i](out)
if self.t_emb_dim is not None:
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
if self.attn:
# Attention block of Unet
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
if self.cross_attn:
assert context is not None, "context cannot be None if cross attention layers are used"
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.cross_attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
context_proj = self.context_proj[i](context)
out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
# Downsample
out = self.down_sample_conv(out)
return out
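# Shape sketch for DownBlock with assumed toy sizes: with down_sample=True the
# spatial resolution halves and channels go in_channels -> out_channels.
# Defined only, never called.
def _down_block_demo():
    block = DownBlock(32, 64, t_emb_dim=128, down_sample=True, num_heads=4,
                      num_layers=1, attn=True, norm_channels=8)
    x = torch.randn(2, 32, 16, 16)
    t_emb = torch.randn(2, 128)
    print(block(x, t_emb).shape)  # torch.Size([2, 64, 8, 8])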
class MidBlock(nn.Module):
r"""
Mid conv block with attention.
Sequence of following blocks
1. Resnet block with time embedding
2. Attention block
3. Resnet block with time embedding
"""
def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
super().__init__()
self.num_layers = num_layers
self.t_emb_dim = t_emb_dim
self.context_dim = context_dim
self.cross_attn = cross_attn
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers + 1)
]
)
if self.t_emb_dim is not None:
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers + 1)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers + 1)
]
)
self.attention_norms = nn.ModuleList(
[nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)]
)
self.attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
if self.cross_attn:
assert context_dim is not None, "Context Dimension must be passed for cross attention"
self.cross_attention_norms = nn.ModuleList(
[nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)]
)
self.cross_attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.context_proj = nn.ModuleList(
[nn.Linear(context_dim, out_channels)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers + 1)
]
)
def forward(self, x, t_emb=None, context=None):
out = x
# First resnet block
resnet_input = out
out = self.resnet_conv_first[0](out)
if self.t_emb_dim is not None:
out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
out = self.resnet_conv_second[0](out)
out = out + self.residual_input_conv[0](resnet_input)
for i in range(self.num_layers):
# Attention Block
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
if self.cross_attn:
assert context is not None, "context cannot be None if cross attention layers are used"
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.cross_attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim
context_proj = self.context_proj[i](context)
out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
# Resnet Block
resnet_input = out
out = self.resnet_conv_first[i + 1](out)
if self.t_emb_dim is not None:
out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i + 1](out)
out = out + self.residual_input_conv[i + 1](resnet_input)
return out
class UpBlock(nn.Module):
r"""
Up conv block with attention.
Sequence of following blocks
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
"""
def __init__(self, in_channels, out_channels, t_emb_dim,
up_sample, num_heads, num_layers, attn, norm_channels):
super().__init__()
self.num_layers = num_layers
self.up_sample = up_sample
self.t_emb_dim = t_emb_dim
self.attn = attn
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers)
]
)
if self.t_emb_dim is not None:
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
if self.attn:
self.attention_norms = nn.ModuleList(
[
nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)
]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)
]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.up_sample_conv = nn.ConvTranspose2d(in_channels, in_channels,
4, 2, 1) \
if self.up_sample else nn.Identity()
def forward(self, x, out_down=None, t_emb=None):
# Upsample
x = self.up_sample_conv(x)
# Concat with Downblock output
if out_down is not None:
x = torch.cat([x, out_down], dim=1)
out = x
for i in range(self.num_layers):
# Resnet Block
resnet_input = out
out = self.resnet_conv_first[i](out)
if self.t_emb_dim is not None:
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
# Self Attention
if self.attn:
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
return out
class UpBlockUnet(nn.Module):
r"""
Up conv block with attention.
Sequence of following blocks
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
"""
def __init__(self, in_channels, out_channels, t_emb_dim, up_sample,
num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None):
super().__init__()
self.num_layers = num_layers
self.up_sample = up_sample
self.t_emb_dim = t_emb_dim
self.cross_attn = cross_attn
self.context_dim = context_dim
self.resnet_conv_first = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels),
nn.SiLU(),
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
padding=1),
)
for i in range(num_layers)
]
)
if self.t_emb_dim is not None:
self.t_emb_layers = nn.ModuleList([
nn.Sequential(
nn.SiLU(),
nn.Linear(t_emb_dim, out_channels)
)
for _ in range(num_layers)
])
self.resnet_conv_second = nn.ModuleList(
[
nn.Sequential(
nn.GroupNorm(norm_channels, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
)
for _ in range(num_layers)
]
)
self.attention_norms = nn.ModuleList(
[
nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)
]
)
self.attentions = nn.ModuleList(
[
nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)
]
)
if self.cross_attn:
assert context_dim is not None, "Context Dimension must be passed for cross attention"
self.cross_attention_norms = nn.ModuleList(
[nn.GroupNorm(norm_channels, out_channels)
for _ in range(num_layers)]
)
self.cross_attentions = nn.ModuleList(
[nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
for _ in range(num_layers)]
)
self.context_proj = nn.ModuleList(
[nn.Linear(context_dim, out_channels)
for _ in range(num_layers)]
)
self.residual_input_conv = nn.ModuleList(
[
nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
for i in range(num_layers)
]
)
self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
4, 2, 1) \
if self.up_sample else nn.Identity()
def forward(self, x, out_down=None, t_emb=None, context=None):
x = self.up_sample_conv(x)
if out_down is not None:
x = torch.cat([x, out_down], dim=1)
out = x
for i in range(self.num_layers):
# Resnet
resnet_input = out
out = self.resnet_conv_first[i](out)
if self.t_emb_dim is not None:
out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
out = self.resnet_conv_second[i](out)
out = out + self.residual_input_conv[i](resnet_input)
# Self Attention
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
# Cross Attention
if self.cross_attn:
assert context is not None, "context cannot be None if cross attention layers are used"
batch_size, channels, h, w = out.shape
in_attn = out.reshape(batch_size, channels, h * w)
in_attn = self.cross_attention_norms[i](in_attn)
in_attn = in_attn.transpose(1, 2)
assert len(context.shape) == 3, \
"Context shape does not match B,_,CONTEXT_DIM"
assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\
"Context shape does not match B,_,CONTEXT_DIM"
context_proj = self.context_proj[i](context)
out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj)
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
out = out + out_attn
return out
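# Skip-connection shape sketch for UpBlockUnet with assumed toy sizes: x enters
# with in_channels // 2 channels, is upsampled, then concatenated with the
# matching encoder output (also in_channels // 2), which is why the U-Net
# further below builds its up blocks with in_channels = down_channels[i] * 2.
# Defined only, never called.
def _up_block_unet_demo():
    up = UpBlockUnet(in_channels=128, out_channels=32, t_emb_dim=128, up_sample=True,
                     num_heads=4, num_layers=1, norm_channels=8)
    x = torch.randn(2, 64, 4, 4)          # decoder feature map
    down_out = torch.randn(2, 64, 8, 8)   # saved encoder feature map
    t_emb = torch.randn(2, 128)
    print(up(x, down_out, t_emb).shape)   # torch.Size([2, 32, 8, 8])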
"""Vqvae"""
import torch
import torch.nn as nn
class VQVAE(nn.Module):
def __init__(self, im_channels, model_config):
super().__init__()
self.down_channels = model_config.down_channels
self.mid_channels = model_config.mid_channels
self.down_sample = model_config.down_sample
self.num_down_layers = model_config.num_down_layers
self.num_mid_layers = model_config.num_mid_layers
self.num_up_layers = model_config.num_up_layers
# To disable attention in Downblock of Encoder and Upblock of Decoder
self.attns = model_config.attn_down
# Latent Dimension
self.z_channels = model_config.z_channels
self.codebook_size = model_config.codebook_size
self.norm_channels = model_config.norm_channels
self.num_heads = model_config.num_heads
# Assertion to validate the channel information
assert self.mid_channels[0] == self.down_channels[-1]
assert self.mid_channels[-1] == self.down_channels[-1]
assert len(self.down_sample) == len(self.down_channels) - 1
assert len(self.attns) == len(self.down_channels) - 1
# Wherever we use downsampling in encoder correspondingly use
# upsampling in decoder
self.up_sample = list(reversed(self.down_sample))
##################### Encoder ######################
self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
# Downblock + Midblock
self.encoder_layers = nn.ModuleList([])
for i in range(len(self.down_channels) - 1):
self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
t_emb_dim=None, down_sample=self.down_sample[i],
num_heads=self.num_heads,
num_layers=self.num_down_layers,
attn=self.attns[i],
norm_channels=self.norm_channels))
self.encoder_mids = nn.ModuleList([])
for i in range(len(self.mid_channels) - 1):
self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
t_emb_dim=None,
num_heads=self.num_heads,
num_layers=self.num_mid_layers,
norm_channels=self.norm_channels))
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], self.z_channels, kernel_size=3, padding=1)
# Pre Quantization Convolution
self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
# Codebook
self.embedding = nn.Embedding(self.codebook_size, self.z_channels)
####################################################
##################### Decoder ######################
# Post Quantization Convolution
self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
# Midblock + Upblock
self.decoder_mids = nn.ModuleList([])
for i in reversed(range(1, len(self.mid_channels))):
self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
t_emb_dim=None,
num_heads=self.num_heads,
num_layers=self.num_mid_layers,
norm_channels=self.norm_channels))
self.decoder_layers = nn.ModuleList([])
for i in reversed(range(1, len(self.down_channels))):
self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
t_emb_dim=None, up_sample=self.down_sample[i - 1],
num_heads=self.num_heads,
num_layers=self.num_up_layers,
attn=self.attns[i-1],
norm_channels=self.norm_channels))
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
def quantize(self, x):
B, C, H, W = x.shape
# B, C, H, W -> B, H, W, C
x = x.permute(0, 2, 3, 1)
# B, H, W, C -> B, H*W, C
x = x.reshape(x.size(0), -1, x.size(-1))
# Find nearest embedding/codebook vector
# dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K)
dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1)))
# (B, H*W)
min_encoding_indices = torch.argmin(dist, dim=-1)
# Replace encoder output with nearest codebook
# quant_out -> B*H*W, C
quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
# x -> B*H*W, C
x = x.reshape((-1, x.size(-1)))
        commitment_loss = torch.mean((quant_out.detach() - x) ** 2)
codebook_loss = torch.mean((quant_out - x.detach()) ** 2)
quantize_losses = {
'codebook_loss': codebook_loss,
            'commitment_loss': commitment_loss
}
# Straight through estimation
quant_out = x + (quant_out - x).detach()
# quant_out -> B, C, H, W
quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
return quant_out, quantize_losses, min_encoding_indices
def encode(self, x):
out = self.encoder_conv_in(x)
for idx, down in enumerate(self.encoder_layers):
out = down(out)
for mid in self.encoder_mids:
out = mid(out)
out = self.encoder_norm_out(out)
out = nn.SiLU()(out)
out = self.encoder_conv_out(out)
out = self.pre_quant_conv(out)
out, quant_losses, _ = self.quantize(out)
return out, quant_losses
def decode(self, z):
out = z
out = self.post_quant_conv(out)
out = self.decoder_conv_in(out)
for mid in self.decoder_mids:
out = mid(out)
for idx, up in enumerate(self.decoder_layers):
out = up(out)
out = self.decoder_norm_out(out)
out = nn.SiLU()(out)
out = self.decoder_conv_out(out)
return out
def forward(self, x):
z, quant_losses = self.encode(x)
out = self.decode(z)
return out, z, quant_losses
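# Round-trip shape sketch with an assumed toy config (SimpleNamespace stands in
# for the DotDict config loaded from YAML further below). Defined only, never
# called.
def _vqvae_demo():
    from types import SimpleNamespace
    cfg = SimpleNamespace(down_channels=[32, 64], mid_channels=[64, 64],
                          down_sample=[True], num_down_layers=1, num_mid_layers=1,
                          num_up_layers=1, attn_down=[False], z_channels=4,
                          codebook_size=512, norm_channels=8, num_heads=4)
    vqvae = VQVAE(im_channels=3, model_config=cfg)
    recon, z, quantize_losses = vqvae(torch.randn(2, 3, 32, 32))
    print(recon.shape, z.shape)  # torch.Size([2, 3, 32, 32]) torch.Size([2, 4, 16, 16])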
"""Vae"""
import torch
import torch.nn as nn
class VAE(nn.Module):
def __init__(self, im_channels, model_config):
super().__init__()
self.down_channels = model_config['down_channels']
self.mid_channels = model_config['mid_channels']
self.down_sample = model_config['down_sample']
self.num_down_layers = model_config['num_down_layers']
self.num_mid_layers = model_config['num_mid_layers']
self.num_up_layers = model_config['num_up_layers']
# To disable attention in Downblock of Encoder and Upblock of Decoder
self.attns = model_config['attn_down']
# Latent Dimension
self.z_channels = model_config['z_channels']
self.norm_channels = model_config['norm_channels']
self.num_heads = model_config['num_heads']
# Assertion to validate the channel information
assert self.mid_channels[0] == self.down_channels[-1]
assert self.mid_channels[-1] == self.down_channels[-1]
assert len(self.down_sample) == len(self.down_channels) - 1
assert len(self.attns) == len(self.down_channels) - 1
# Wherever we use downsampling in encoder correspondingly use
# upsampling in decoder
self.up_sample = list(reversed(self.down_sample))
##################### Encoder ######################
self.encoder_conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
# Downblock + Midblock
self.encoder_layers = nn.ModuleList([])
for i in range(len(self.down_channels) - 1):
self.encoder_layers.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
t_emb_dim=None, down_sample=self.down_sample[i],
num_heads=self.num_heads,
num_layers=self.num_down_layers,
attn=self.attns[i],
norm_channels=self.norm_channels))
self.encoder_mids = nn.ModuleList([])
for i in range(len(self.mid_channels) - 1):
self.encoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
t_emb_dim=None,
num_heads=self.num_heads,
num_layers=self.num_mid_layers,
norm_channels=self.norm_channels))
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1])
self.encoder_conv_out = nn.Conv2d(self.down_channels[-1], 2*self.z_channels, kernel_size=3, padding=1)
# Latent Dimension is 2*Latent because we are predicting mean & variance
self.pre_quant_conv = nn.Conv2d(2*self.z_channels, 2*self.z_channels, kernel_size=1)
####################################################
##################### Decoder ######################
self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1)
self.decoder_conv_in = nn.Conv2d(self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1))
# Midblock + Upblock
self.decoder_mids = nn.ModuleList([])
for i in reversed(range(1, len(self.mid_channels))):
self.decoder_mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i - 1],
t_emb_dim=None,
num_heads=self.num_heads,
num_layers=self.num_mid_layers,
norm_channels=self.norm_channels))
self.decoder_layers = nn.ModuleList([])
for i in reversed(range(1, len(self.down_channels))):
self.decoder_layers.append(UpBlock(self.down_channels[i], self.down_channels[i - 1],
t_emb_dim=None, up_sample=self.down_sample[i - 1],
num_heads=self.num_heads,
num_layers=self.num_up_layers,
attn=self.attns[i - 1],
norm_channels=self.norm_channels))
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0])
self.decoder_conv_out = nn.Conv2d(self.down_channels[0], im_channels, kernel_size=3, padding=1)
def encode(self, x):
out = self.encoder_conv_in(x)
for idx, down in enumerate(self.encoder_layers):
out = down(out)
for mid in self.encoder_mids:
out = mid(out)
out = self.encoder_norm_out(out)
out = nn.SiLU()(out)
out = self.encoder_conv_out(out)
out = self.pre_quant_conv(out)
mean, logvar = torch.chunk(out, 2, dim=1)
std = torch.exp(0.5 * logvar)
sample = mean + std * torch.randn(mean.shape).to(device=x.device)
return sample, out
def decode(self, z):
out = z
out = self.post_quant_conv(out)
out = self.decoder_conv_in(out)
for mid in self.decoder_mids:
out = mid(out)
for idx, up in enumerate(self.decoder_layers):
out = up(out)
out = self.decoder_norm_out(out)
out = nn.SiLU()(out)
out = self.decoder_conv_out(out)
return out
def forward(self, x):
z, encoder_output = self.encode(x)
out = self.decode(z)
return out, encoder_output
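# The encoder output above packs mean and log-variance along the channel
# dimension. This script's training path uses the VQ-VAE, so no KL term appears
# below; a standard KL(N(mean, var) || N(0, I)) term, shown only as a sketch,
# would be (defined only, never called):
def _vae_kl_sketch(encoder_output):
    mean, logvar = torch.chunk(encoder_output, 2, dim=1)
    kl = 0.5 * torch.mean(torch.sum(mean ** 2 + logvar.exp() - 1 - logvar, dim=[1, 2, 3]))
    return kl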
"""Discriminator"""
import torch
import torch.nn as nn
class Discriminator(nn.Module):
r"""
PatchGAN Discriminator.
    Rather than taking an IMG_CHANNELS x IMG_H x IMG_W image all the way down to
    a single scalar value, we instead predict a grid of values,
    where each grid cell is a prediction of how likely
    the discriminator thinks the image patch corresponding
    to that cell is real.
"""
def __init__(self, im_channels=3,
conv_channels=[64, 128, 256],
kernels=[4,4,4,4],
strides=[2,2,2,1],
paddings=[1,1,1,1]):
super().__init__()
self.im_channels = im_channels
activation = nn.LeakyReLU(0.2)
layers_dim = [self.im_channels] + conv_channels + [1]
self.layers = nn.ModuleList([
nn.Sequential(
nn.Conv2d(layers_dim[i], layers_dim[i + 1],
kernel_size=kernels[i],
stride=strides[i],
padding=paddings[i],
                          bias=(i == 0)),
nn.BatchNorm2d(layers_dim[i + 1]) if i != len(layers_dim) - 2 and i != 0 else nn.Identity(),
activation if i != len(layers_dim) - 2 else nn.Identity()
)
for i in range(len(layers_dim) - 1)
])
def forward(self, x):
out = x
for layer in self.layers:
out = layer(out)
return out
# if __name__ == '__main__':
# x = torch.randn((2,3, 256, 256))
# prob = Discriminator(im_channels=3)(x)
# print(prob.shape)
# import os
# image_paths = [os.path.join("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images", f)
# for f in os.listdir("/home/taruntejaneurips23/Ashish/datasets/animefacedata/images")]
# image_paths
import glob
import os
import torchvision
from PIL import Image
from tqdm import tqdm, trange
# from utils.diffusion_utils import load_latents
from torch.utils.data.dataset import Dataset
import pickle
import glob
import os
import torch
def load_latents(latent_path):
r"""
    Simple utility to load pre-computed latents to speed up LDM training
:param latent_path:
:return:
"""
latent_maps = {}
for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
s = pickle.load(open(fname, 'rb'))
for k, v in s.items():
latent_maps[k] = v[0]
return latent_maps
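# load_latents above expects each .pkl to hold a dict mapping an image path to a
# tensor whose first element is that image's latent. A matching writer sketch
# (assumed layout and file name, illustrative only; defined and never called):
def _save_latents_sketch(latent_path, image_path, latent):
    with open(os.path.join(latent_path, 'latents_0.pkl'), 'wb') as f:
        pickle.dump({image_path: latent.unsqueeze(0)}, f)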
def drop_text_condition(text_embed, im, empty_text_embed, text_drop_prob):
if text_drop_prob > 0:
        text_drop_mask = torch.zeros((im.shape[0]), device=im.device).float().uniform_(0, 1) < text_drop_prob
assert empty_text_embed is not None, ("Text Conditioning required as well as"
" text dropping but empty text representation not created")
text_embed[text_drop_mask, :, :] = empty_text_embed[0]
return text_embed
def drop_image_condition(image_condition, im, im_drop_prob):
if im_drop_prob > 0:
        im_drop_mask = torch.zeros((im.shape[0], 1, 1, 1), device=im.device).float().uniform_(0, 1) > im_drop_prob
return image_condition * im_drop_mask
else:
return image_condition
def drop_class_condition(class_condition, class_drop_prob, im):
if class_drop_prob > 0:
        class_drop_mask = torch.zeros((im.shape[0], 1), device=im.device).float().uniform_(0, 1) > class_drop_prob
return class_condition * class_drop_mask
else:
return class_condition
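# The drop_* helpers above implement the condition dropping used for
# classifier-free guidance: with probability *_drop_prob a sample's conditioning
# is zeroed (or swapped for the empty-text embedding) so the model also learns
# an unconditional prediction. Illustrative call (defined only, never called):
def _condition_drop_demo():
    im = torch.randn(8, 3, 32, 32)
    class_cond = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), 10).float()
    return drop_class_condition(class_cond, class_drop_prob=0.1, im=im)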
class MnistDataset(Dataset):
r"""
    Nothing special here. Just a simple dataset class for MNIST images.
    Created a dataset class rather than using torchvision to allow
    replacement with any other image dataset
"""
def __init__(self, split, im_path, im_size, im_channels,
use_latents=False, latent_path=None, condition_config=None):
r"""
Init method for initializing the dataset properties
:param split: train/test to locate the image files
:param im_path: root folder of images
        :param im_size: target image size
        :param im_channels: number of image channels
"""
self.split = split
self.im_size = im_size
self.im_channels = im_channels
# Should we use latents or not
self.latent_maps = None
self.use_latents = False
# Conditioning for the dataset
self.condition_types = [] if condition_config is None else condition_config['condition_types']
self.images, self.labels = self.load_images(im_path)
# Whether to load images and call vae or to load latents
if use_latents and latent_path is not None:
latent_maps = load_latents(latent_path)
if len(latent_maps) == len(self.images):
self.use_latents = True
self.latent_maps = latent_maps
print('Found {} latents'.format(len(self.latent_maps)))
else:
print('Latents not found')
def load_images(self, im_path):
r"""
        Gets all image file paths from the specified path
        and collects them into a list
:param im_path:
:return:
"""
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
ims = []
labels = []
for d_name in tqdm(os.listdir(im_path)):
fnames = glob.glob(os.path.join(im_path, d_name, '*.{}'.format('png')))
fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpg')))
fnames += glob.glob(os.path.join(im_path, d_name, '*.{}'.format('jpeg')))
for fname in fnames:
ims.append(fname)
if 'class' in self.condition_types:
labels.append(int(d_name))
print('Found {} images for split {}'.format(len(ims), self.split))
return ims, labels
def __len__(self):
return len(self.images)
def __getitem__(self, index):
######## Set Conditioning Info ########
cond_inputs = {}
if 'class' in self.condition_types:
cond_inputs['class'] = self.labels[index]
#######################################
if self.use_latents:
latent = self.latent_maps[self.images[index]]
if len(self.condition_types) == 0:
return latent
else:
return latent, cond_inputs
else:
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
if len(self.condition_types) == 0:
return im_tensor
else:
return im_tensor, cond_inputs
class AnimeFaceDataset(Dataset):
def __init__(self, split, im_path, im_size, im_channels,
use_latents=False, latent_path=None, condition_config=None):
self.split = split
self.im_size = im_size
self.im_channels = im_channels
# Should we use latents or not
self.latent_maps = None
self.use_latents = False
# Conditioning for the dataset
self.condition_types = [] if condition_config is None else condition_config['condition_types']
self.images = self.load_images(im_path)
# Whether to load images and call vae or to load latents
if use_latents and latent_path is not None:
latent_maps = load_latents(latent_path)
if len(latent_maps) == len(self.images):
self.use_latents = True
self.latent_maps = latent_maps
print('Found {} latents'.format(len(self.latent_maps)))
else:
print('Latents not found')
def load_images(self, im_path):
r"""
        Gets all image file paths from the specified path
        and collects them into a list
:param im_path:
:return:
"""
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
# ims = []
# labels = []
ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
return ims
def __len__(self):
return len(self.images)
def __getitem__(self, index):
######## Set Conditioning Info ########
# cond_inputs = {}
# if 'class' in self.condition_types:
# cond_inputs['class'] = self.labels[index]
#######################################
if self.use_latents:
latent = self.latent_maps[self.images[index]]
if len(self.condition_types) == 0:
return latent
# else:
# return latent, cond_inputs
else:
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.Compose([
torchvision.transforms.Resize(self.im_size),
torchvision.transforms.CenterCrop(self.im_size),
torchvision.transforms.ToTensor(),
])(im)
im.close()
# im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
if len(self.condition_types) == 0:
return im_tensor
# else:
# return im_tensor, cond_inputs
import glob
import os
import random
import torch
import torchvision
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
class CelebDataset(Dataset):
def __init__(self, split, im_path, im_size, im_channels,
use_latents=False, latent_path=None, condition_config=None):
self.split = split
self.im_size = im_size
self.im_channels = im_channels
# Should we use latents or not
self.latent_maps = None
self.use_latents = False
# Conditioning for the dataset
self.condition_types = [] if condition_config is None else condition_config['condition_types']
self.images = self.load_images(im_path)
# Whether to load images and call vae or to load latents
if use_latents and latent_path is not None:
latent_maps = load_latents(latent_path)
if len(latent_maps) == len(self.images):
self.use_latents = True
self.latent_maps = latent_maps
print('Found {} latents'.format(len(self.latent_maps)))
else:
print('Latents not found')
def load_images(self, im_path):
r"""
        Gets all image file paths from the specified path
        and collects them into a list
:param im_path:
:return:
"""
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
# ims = []
# labels = []
ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
return ims
def __len__(self):
return len(self.images)
def __getitem__(self, index):
######## Set Conditioning Info ########
# cond_inputs = {}
# if 'class' in self.condition_types:
# cond_inputs['class'] = self.labels[index]
#######################################
if self.use_latents:
latent = self.latent_maps[self.images[index]]
if len(self.condition_types) == 0:
return latent
# else:
# return latent, cond_inputs
else:
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.Compose([
# torchvision.transforms.Resize(self.im_size),
torchvision.transforms.CenterCrop(self.im_size),
torchvision.transforms.ToTensor(),
])(im)
im.close()
# im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
if len(self.condition_types) == 0:
return im_tensor
# else:
# return im_tensor, cond_inputs
import pandas as pd
class CelebHairDataset(Dataset):
def __init__(self, split, im_path, im_size, im_channels,
use_latents=False, latent_path=None, condition_config=None):
self.df = pd.read_csv("/home/taruntejaneurips23/Ashish/DDPM/hair_df_100.csv")
self.split = split
self.im_size = im_size
self.im_channels = im_channels
# Should we use latents or not
self.latent_maps = None
self.use_latents = False
# Conditioning for the dataset
self.condition_types = [] if condition_config is None else condition_config['condition_types']
self.images = self.load_images(im_path, self.df)
# Whether to load images and call vae or to load latents
if use_latents and latent_path is not None:
latent_maps = load_latents(latent_path)
if len(latent_maps) == len(self.images):
self.use_latents = True
self.latent_maps = latent_maps
print('Found {} latents'.format(len(self.latent_maps)))
else:
print('Latents not found')
def load_images(self, im_path, df):
r"""
        Gets all image file paths from the specified path
        and collects them into a list
:param im_path:
:return:
"""
assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
# ims = []
# labels = []
# ims = [os.path.join(im_path, f) for f in os.listdir(im_path)]
ims = [os.path.join(im_path, i) for i in df.image_id.values]
return ims
def __len__(self):
return len(self.images)
def __getitem__(self, index):
######## Set Conditioning Info ########
# cond_inputs = {}
# if 'class' in self.condition_types:
# cond_inputs['class'] = self.labels[index]
#######################################
if self.use_latents:
latent = self.latent_maps[self.images[index]]
if len(self.condition_types) == 0:
return latent
# else:
# return latent, cond_inputs
else:
im = Image.open(self.images[index])
im_tensor = torchvision.transforms.Compose([
# torchvision.transforms.Resize(self.im_size),
torchvision.transforms.CenterCrop(self.im_size),
torchvision.transforms.ToTensor(),
])(im)
im.close()
# im_tensor = torchvision.transforms.ToTensor()(im)
# Convert input to -1 to 1 range.
im_tensor = (2 * im_tensor) - 1
if len(self.condition_types) == 0:
return im_tensor
# else:
# return im_tensor, cond_inputs
#"""Train VQVAE"""...............................................................................................................................................
# Commented out IPython magic to ensure Python compatibility.
import torch
import torch.nn as nn
import yaml
from dotdict import DotDict
config_path = "/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml"
with open(config_path, 'r') as file:
Config = yaml.safe_load(file)
Config = DotDict.from_dict(Config)
dataset_config = Config.dataset_params
diffusion_config = Config.diffusion_params
model_config = Config.model_params
train_config = Config.train_params
import torch
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, TensorDataset, DataLoader
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from torchvision.utils import make_grid
def trainVAE(Config):
dataset_config = Config.dataset_params
autoencoder_config = Config.autoencoder_params
train_config = Config.train_params
# Set the desired seed value #
seed = train_config.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if device == 'cuda':
torch.cuda.manual_seed_all(seed)
#############################
# Create the model and dataset #
model = VQVAE(im_channels=dataset_config.im_channels,
model_config=autoencoder_config).to(device)
# model.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_autoencoder_ckpt.pth", map_location=device))
if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name)):
print('Loaded vae checkpoint')
model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_autoencoder_ckpt_name),
map_location=device, weights_only=True))
# Create the dataset
im_dataset_cls = {
'mnist': MnistDataset,
'celebA': CelebDataset,
'animeface': AnimeFaceDataset,
'celebAhair': CelebHairDataset
}.get(dataset_config.name)
im_dataset = im_dataset_cls(split='train',
im_path=dataset_config.im_path,
im_size=dataset_config.im_size,
im_channels=dataset_config.im_channels)
data_loader = DataLoader(im_dataset,
batch_size=train_config.autoencoder_batch_size,
shuffle=True,
num_workers=os.cpu_count(),
pin_memory=True,
drop_last=True,
persistent_workers=True, pin_memory_device=device)
# Create output directories
if not os.path.exists(train_config.task_name):
os.mkdir(train_config.task_name)
num_epochs = train_config.autoencoder_epochs
# L1/L2 loss for Reconstruction
recon_criterion = torch.nn.MSELoss()
# Disc Loss can even be BCEWithLogits
disc_criterion = torch.nn.MSELoss()
# No need to freeze lpips as lpips.py takes care of that
lpips_model = LPIPS().eval().to(device)
discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
# discriminator.load_state_dict(torch.load("/home/taruntejaneurips23/Ashish/DDPM/celebAhair_ldm/vqvae_discriminator_ckpt.pth", map_location=device))
if os.path.exists(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name)):
print('Loaded discriminator checkpoint')
discriminator.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.vqvae_discriminator_ckpt_name),
map_location=device, weights_only=True))
optimizer_d = Adam(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
optimizer_g = Adam(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
disc_step_start = train_config.disc_start
step_count = 0
    # This is for accumulating gradients in case the images are huge
    # and one can't afford larger batch sizes
acc_steps = train_config.autoencoder_acc_steps
image_save_steps = train_config.autoencoder_img_save_steps
img_save_count = 0
for epoch_idx in trange(num_epochs, desc='Training VQVAE'):
recon_losses = []
codebook_losses = []
#commitment_losses = []
perceptual_losses = []
disc_losses = []
gen_losses = []
losses = []
optimizer_g.zero_grad()
optimizer_d.zero_grad()
# for im in tqdm(data_loader):
for im in data_loader:
step_count += 1
im = im.float().to(device)
# Fetch autoencoders output(reconstructions)
model_output = model(im)
output, z, quantize_losses = model_output
# Image Saving Logic
if step_count % image_save_steps == 0 or step_count == 1:
sample_size = min(8, im.shape[0])
save_output = torch.clamp(output[:sample_size], -1., 1.).detach().cpu()
save_output = ((save_output + 1) / 2)
save_input = ((im[:sample_size] + 1) / 2).detach().cpu()
grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
img = torchvision.transforms.ToPILImage()(grid)
if not os.path.exists(os.path.join(train_config.task_name,'vqvae_autoencoder_samples')):
os.mkdir(os.path.join(train_config.task_name, 'vqvae_autoencoder_samples'))
img.save(os.path.join(train_config.task_name,'vqvae_autoencoder_samples',
'current_autoencoder_sample_{}.png'.format(img_save_count)))
img_save_count += 1
img.close()
######### Optimize Generator ##########
# L2 Loss
recon_loss = recon_criterion(output, im)
recon_losses.append(recon_loss.item())
recon_loss = recon_loss / acc_steps
g_loss = (recon_loss +
(train_config.codebook_weight * quantize_losses['codebook_loss'] / acc_steps) +
(train_config.commitment_beta * quantize_losses['commitment_loss'] / acc_steps))
codebook_losses.append(train_config.codebook_weight * quantize_losses['codebook_loss'].item())
# Adversarial loss only if disc_step_start steps passed
if step_count > disc_step_start:
disc_fake_pred = discriminator(model_output[0])
disc_fake_loss = disc_criterion(disc_fake_pred,
torch.ones(disc_fake_pred.shape,
device=disc_fake_pred.device))
gen_losses.append(train_config.disc_weight * disc_fake_loss.item())
g_loss += train_config.disc_weight * disc_fake_loss / acc_steps
lpips_loss = torch.mean(lpips_model(output, im)) / acc_steps
perceptual_losses.append(train_config.perceptual_weight * lpips_loss.item())
g_loss += train_config.perceptual_weight*lpips_loss / acc_steps
losses.append(g_loss.item())
g_loss.backward()
#####################################
######### Optimize Discriminator #######
if step_count > disc_step_start:
fake = output
disc_fake_pred = discriminator(fake.detach())
disc_real_pred = discriminator(im)
disc_fake_loss = disc_criterion(disc_fake_pred,
torch.zeros(disc_fake_pred.shape,
device=disc_fake_pred.device))
disc_real_loss = disc_criterion(disc_real_pred,
torch.ones(disc_real_pred.shape,
device=disc_real_pred.device))
disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
disc_losses.append(disc_loss.item())
disc_loss = disc_loss / acc_steps
disc_loss.backward()
if step_count % acc_steps == 0:
optimizer_d.step()
optimizer_d.zero_grad()
#####################################
if step_count % acc_steps == 0:
optimizer_g.step()
optimizer_g.zero_grad()
optimizer_d.step()
optimizer_d.zero_grad()
optimizer_g.step()
optimizer_g.zero_grad()
if len(disc_losses) > 0:
print(
'Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | '
'Codebook : {:.4f} | G Loss : {:.4f} | D Loss {:.4f}'.
format(epoch_idx + 1,
num_epochs,
np.mean(recon_losses),
np.mean(perceptual_losses),
np.mean(codebook_losses),
np.mean(gen_losses),
np.mean(disc_losses)))
else:
print('Finished epoch: {}/{} | Recon Loss : {:.4f} | Perceptual Loss : {:.4f} | Codebook : {:.4f}'.
format(epoch_idx + 1,
num_epochs,
np.mean(recon_losses),
np.mean(perceptual_losses),
np.mean(codebook_losses)))
torch.save(model.state_dict(), os.path.join(train_config.task_name,
train_config.vqvae_autoencoder_ckpt_name))
torch.save(discriminator.state_dict(), os.path.join(train_config.task_name,
train_config.vqvae_discriminator_ckpt_name))
print('Done Training...')
# trainVAE(Config)
import torch
import torch.nn as nn
class Unet(nn.Module):
r"""
Unet model comprising
Down blocks, Midblocks and Uplocks
"""
def __init__(self, im_channels, model_config):
super().__init__()
self.down_channels = model_config.down_channels
self.mid_channels = model_config.mid_channels
self.t_emb_dim = model_config.time_emb_dim
self.down_sample = model_config.down_sample
self.num_down_layers = model_config.num_down_layers
self.num_mid_layers = model_config.num_mid_layers
self.num_up_layers = model_config.num_up_layers
self.attns = model_config.attn_down
self.norm_channels = model_config.norm_channels
self.num_heads = model_config.num_heads
self.conv_out_channels = model_config.conv_out_channels
assert self.mid_channels[0] == self.down_channels[-1]
assert self.mid_channels[-1] == self.down_channels[-2]
assert len(self.down_sample) == len(self.down_channels) - 1
assert len(self.attns) == len(self.down_channels) - 1
# Initial projection from sinusoidal time embedding
self.t_proj = nn.Sequential(
nn.Linear(self.t_emb_dim, self.t_emb_dim),
nn.SiLU(),
nn.Linear(self.t_emb_dim, self.t_emb_dim)
)
self.up_sample = list(reversed(self.down_sample))
self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)
self.downs = nn.ModuleList([])
for i in range(len(self.down_channels) - 1):
self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim,
down_sample=self.down_sample[i],
num_heads=self.num_heads,
num_layers=self.num_down_layers,
attn=self.attns[i], norm_channels=self.norm_channels))
self.mids = nn.ModuleList([])
for i in range(len(self.mid_channels) - 1):
self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
num_heads=self.num_heads,
num_layers=self.num_mid_layers,
norm_channels=self.norm_channels))
self.ups = nn.ModuleList([])
for i in reversed(range(len(self.down_channels) - 1)):
self.ups.append(UpBlockUnet(self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
self.t_emb_dim, up_sample=self.down_sample[i],
num_heads=self.num_heads,
num_layers=self.num_up_layers,
norm_channels=self.norm_channels))
self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1)
def forward(self, x, t):
# Shapes assuming downblocks are [C1, C2, C3, C4]
# Shapes assuming midblocks are [C4, C4, C3]
# Shapes assuming downsamples are [True, True, False]
# B x C x H x W
out = self.conv_in(x)
# B x C1 x H x W
# t_emb -> B x t_emb_dim
t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
t_emb = self.t_proj(t_emb)
down_outs = []
for idx, down in enumerate(self.downs):
down_outs.append(out)
out = down(out, t_emb)
# down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
# out B x C4 x H/4 x W/4
for mid in self.mids:
out = mid(out, t_emb)
# out B x C3 x H/4 x W/4
for up in self.ups:
down_out = down_outs.pop()
out = up(out, down_out, t_emb)
# out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]
out = self.norm_out(out)
out = nn.SiLU()(out)
out = self.conv_out(out)
# out B x C x H x W
return out
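# Forward-pass sketch for the U-Net with an assumed toy config (real values come
# from the ldm_params section of the YAML config). Defined only, never called.
def _unet_demo():
    from types import SimpleNamespace
    unet_cfg = SimpleNamespace(down_channels=[32, 64, 64], mid_channels=[64, 64],
                               time_emb_dim=128, down_sample=[True, False],
                               num_down_layers=1, num_mid_layers=1, num_up_layers=1,
                               attn_down=[True, True], norm_channels=8, num_heads=4,
                               conv_out_channels=32)
    unet = Unet(im_channels=4, model_config=unet_cfg)
    z = torch.randn(2, 4, 16, 16)
    t = torch.randint(0, 1000, (2,))
    print(unet(z, t).shape)  # torch.Size([2, 4, 16, 16])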
def trainLDM(Config):
diffusion_config = Config.diffusion_params
dataset_config = Config.dataset_params
diffusion_model_config = Config.ldm_params
autoencoder_model_config = Config.autoencoder_params
train_config = Config.train_params
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
beta_start=diffusion_config.beta_start,
beta_end=diffusion_config.beta_end)
# scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
im_dataset_cls = {
'mnist': MnistDataset,
'celebA': CelebDataset,
'animeface': AnimeFaceDataset,
'celebAhair': CelebHairDataset
}.get(dataset_config.name)
im_dataset = im_dataset_cls(split='train',
im_path=dataset_config.im_path,
im_size=dataset_config.im_size,
im_channels=dataset_config.im_channels,
use_latents=True,
latent_path=os.path.join(train_config.task_name,
train_config.vqvae_latent_dir_name)
)
data_loader = DataLoader(im_dataset,
batch_size=train_config.ldm_batch_size,
shuffle=True,
num_workers=os.cpu_count(),
pin_memory=True,
drop_last=False,
persistent_workers=True, pin_memory_device=str(device))
# Instantiate the model
model = Unet(im_channels=autoencoder_model_config.z_channels,
model_config=diffusion_model_config).to(device)
if os.path.exists(os.path.join(train_config.task_name, train_config.ldm_ckpt_name)):
print('Loaded ldm checkpoint')
model.load_state_dict(torch.load(os.path.join(train_config.task_name, train_config.ldm_ckpt_name), map_location=device, weights_only=True))
model.train()
# Load VAE ONLY if latents are not to be used or are missing
if not im_dataset.use_latents:
print('Loading vqvae model as latents not present')
vae = VQVAE(im_channels=dataset_config.im_channels,
model_config=autoencoder_model_config).to(device)
vae.eval()
# Load vae if found
if os.path.exists(os.path.join(train_config.task_name,
train_config.vqvae_autoencoder_ckpt_name)):
print('Loaded vae checkpoint')
vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
train_config.vqvae_autoencoder_ckpt_name),
map_location=device))
# Specify training parameters
num_epochs = train_config.ldm_epochs
optimizer = Adam(model.parameters(), lr=train_config.ldm_lr)
criterion = torch.nn.MSELoss()
# Run training
if not im_dataset.use_latents:
for param in vae.parameters():
param.requires_grad = False
for epoch_idx in range(num_epochs):
losses = []
for im in tqdm(data_loader):
optimizer.zero_grad()
im = im.float().to(device)
if not im_dataset.use_latents:
with torch.no_grad():
im, _ = vae.encode(im)
# Sample random noise
noise = torch.randn_like(im).to(device)
# Sample timestep
t = torch.randint(0, diffusion_config.num_timesteps, (im.shape[0],)).to(device)
# Add noise to images according to timestep
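# Closed-form forward process (standard DDPM, which LinearNoiseScheduler is
# assumed to implement): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise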
noisy_im = scheduler.add_noise(im, noise, t)
noise_pred = model(noisy_im, t)
loss = criterion(noise_pred, noise)
losses.append(loss.item())
loss.backward()
optimizer.step()
print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}')
torch.save(model.state_dict(), os.path.join(train_config.task_name,
train_config.ldm_ckpt_name))
# Run inference after each epoch to monitor sample quality
infer(Config)
# Check whether to continue training (flag re-read from the run's YAML config)
train_continue = yaml.safe_load(open("/home/taruntejaneurips23/Ashish/DDPM/_5_ldm_celeba.yaml", 'r'))
train_continue = DotDict.from_dict(train_continue)
if not train_continue.training._continue_:
print('Training stopped ...')
break
print('Done Training ...')
# trainLDM(Config)
# import subprocess
# subprocess.run(f'kill {os.getpid()}', shell=True, check=True)
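# The CosineNoiseScheduler left commented out in trainLDM/infer is assumed to
# follow the Nichol & Dhariwal (2021) schedule, where
# alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi / 2).
# A minimal sketch of deriving betas from that schedule (reference only):
def _cosine_betas_sketch(num_timesteps, s=0.008, max_beta=0.999):
    steps = torch.arange(num_timesteps + 1, dtype=torch.float64)
    alpha_bar = torch.cos(((steps / num_timesteps) + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]  # normalize so alpha_bar(0) == 1
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])  # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    return torch.clamp(betas, 0.0, max_beta).float()  # clip to avoid singularities near t = T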
def sample(model, scheduler, train_config, diffusion_model_config,
autoencoder_model_config, diffusion_config, dataset_config, vae):
r"""
Sample stepwise by going backward one timestep at a time.
We save the x0 predictions
"""
im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample)
xt = torch.randn((train_config.num_samples,
autoencoder_model_config.z_channels,
im_size,
im_size)).to(device)
save_count = 0
for i in tqdm(reversed(range(diffusion_config.num_timesteps)), total=diffusion_config.num_timesteps):
# Get prediction of noise
noise_pred = model(xt, torch.as_tensor(i).unsqueeze(0).to(device))
# Use scheduler to get x0 and xt-1
xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))
# Save x0
#ims = torch.clamp(xt, -1., 1.).detach().cpu()
if i == 0:
# Decode ONLY the final image to save time
ims = vae.decode(xt)
else:
ims = xt
ims = torch.clamp(ims, -1., 1.).detach().cpu()
ims = (ims + 1) / 2
grid = make_grid(ims, nrow=train_config.num_grid_rows)
img = torchvision.transforms.ToPILImage()(grid)
os.makedirs(os.path.join(train_config.task_name, 'samples'), exist_ok=True)
img.save(os.path.join(train_config.task_name, 'samples', 'x0_{}.png'.format(i)))
img.close()
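# For reference: sample_prev_timestep above is assumed to implement the
# standard DDPM reverse update. A minimal sketch of that math, with
# hypothetical 1-D tensors betas / alphas / alpha_bars indexed by t:
def _ddpm_reverse_step_sketch(xt, noise_pred, t, betas, alphas, alpha_bars):
    # Predict x0 from the noise estimate, then clamp to the image range
    x0 = (xt - torch.sqrt(1 - alpha_bars[t]) * noise_pred) / torch.sqrt(alpha_bars[t])
    x0 = torch.clamp(x0, -1., 1.)
    # Mean of the posterior q(x_{t-1} | x_t, x0)
    mean = (xt - (betas[t] / torch.sqrt(1 - alpha_bars[t])) * noise_pred) / torch.sqrt(alphas[t])
    if t == 0:
        return mean, x0  # no noise is added at the final step
    variance = (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * betas[t]
    return mean + torch.sqrt(variance) * torch.randn_like(xt), x0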
def infer(Config):
diffusion_config = Config.diffusion_params
dataset_config = Config.dataset_params
diffusion_model_config = Config.ldm_params
autoencoder_model_config = Config.autoencoder_params
train_config = Config.train_params
# Create the noise scheduler
scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps,
beta_start=diffusion_config.beta_start,
beta_end=diffusion_config.beta_end)
# scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps)
model = Unet(im_channels=autoencoder_model_config.z_channels,
model_config=diffusion_model_config).to(device)
model.eval()
if os.path.exists(os.path.join(train_config.task_name,
train_config.ldm_ckpt_name)):
print('Loaded unet checkpoint')
model.load_state_dict(torch.load(os.path.join(train_config.task_name,
train_config.ldm_ckpt_name),
map_location=device))
# Create output directories
os.makedirs(train_config.task_name, exist_ok=True)
vae = VQVAE(im_channels=dataset_config.im_channels,
model_config=autoencoder_model_config).to(device)
vae.eval()
# Load vae if found
if os.path.exists(os.path.join(train_config.task_name,
train_config.vqvae_autoencoder_ckpt_name)):
print('Loaded vae checkpoint')
vae.load_state_dict(torch.load(os.path.join(train_config.task_name,
train_config.vqvae_autoencoder_ckpt_name),
map_location=device), strict=True)
with torch.no_grad():
sample(model, scheduler, train_config, diffusion_model_config,
autoencoder_model_config, diffusion_config, dataset_config, vae)
import argparse
def get_args():
parser = argparse.ArgumentParser(description="Choose between train VAE, train LDM, or infer mode.")
parser.add_argument('--mode', choices=['train_vae', 'train_ldm', 'infer'], default='infer',
help="Mode to run: train_vae, train_ldm, or infer")
return parser.parse_args()
args = get_args()
if args.mode == 'train_vae':
trainVAE(Config)
elif args.mode == 'train_ldm':
trainLDM(Config)
else:
infer(Config)
# python _5.2_ldm_celeba_hair_cosine.py --mode train_vae
# python _5.2_ldm_celeba_hair_cosine.py --mode train_ldm
# python _5.2_ldm_celeba_hair_cosine.py --mode infer
# import matplotlib.pyplot as plt
# from PIL import Image
# # plt.style.use('dark_background')
# # %matplotlib inline
# plt.imshow(Image.open('/home/taruntejaneurips23/Ashish/DDPM/mnist_ldm/samples/x0_0.png'), cmap='gray')
# import matplotlib.pyplot as plt
# import matplotlib.image as mpimg
# dataset_name = 'animeface_ldm'
# image_paths = [f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_0.png',
# f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_1.png',
# f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_5.png',
# f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_100.png',
# f'/home/taruntejaneurips23/Ashish/DDPM/{dataset_name}/samples/x0_200.png'
# ]
# fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
# for i, path in enumerate(image_paths):
# img = mpimg.imread(path)
# axes[i].imshow(img)
# axes[i].axis('off') # Hide axes
# axes[i].set_title(f't = {path.split("/")[-1].split(".")[0].split("_")[-1]}')
# plt.tight_layout()
# plt.show()
# ---------------------------------------------------------
# ---------- T H E - E N D -------------------------------
# ---------------------------------------------------------
def save_checkpoint(
total_steps, epoch, model, discriminator,
optimizer_d, optimizer_g, loss, checkpoint_path
):
checkpoint = {
"total_steps": total_steps,
"epoch": epoch,
"model_state_dict": model.state_dict(),
"discriminator_state_dict": discriminator.state_dict(),
"optimizer_d_state_dict": optimizer_d.state_dict(),
"optimizer_g_state_dict": optimizer_g.state_dict(),
"loss": loss,
}
torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved after {total_steps} steps at epoch {epoch}")
def load_checkpoint(
checkpoint_path, model, discriminator, optimizer_d, optimizer_g
):
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
# Optimizers may be None (e.g. when only the model weights are needed)
if optimizer_d is not None:
optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
if optimizer_g is not None:
optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
total_steps = checkpoint["total_steps"]
start_epoch = checkpoint["epoch"] + 1
loss = checkpoint["loss"]
print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
return total_steps, start_epoch, loss
else:
print("No checkpoint found. Starting from scratch.")
return 0, 0, None
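# Typical resume pattern (sketch): the optimizers may be passed as None when
# only the model weights are needed, as in the trainVAE call below.
# total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g)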
def trainVAE(Config, dataloader=None):
"""
Trains a VQVAE model using the provided configuration and data loader.
A default image loader is built from the config when none is passed (as in
the trainVAE(Config) call in the mode dispatch above).
"""
# --- Configurations ----------------------------------------------------
dataset_config = Config.dataset_params
autoencoder_config = Config.autoencoder_params
train_config = Config.train_params
seed = train_config.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if device.type == "cuda":
torch.cuda.manual_seed_all(seed)
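# Build a default image loader when none is supplied, mirroring the dataset
# selection in trainLDM but on raw images (no use_latents), since the VAE
# itself produces the latents. NOTE: train_config.autoencoder_batch_size is an
# assumed config key, named in the style of the other autoencoder_* keys here.
if dataloader is None:
im_dataset_cls = {
'mnist': MnistDataset,
'celebA': CelebDataset,
'animeface': AnimeFaceDataset,
'celebAhair': CelebHairDataset
}.get(dataset_config.name)
im_dataset = im_dataset_cls(split='train',
im_path=dataset_config.im_path,
im_size=dataset_config.im_size,
im_channels=dataset_config.im_channels)
dataloader = DataLoader(im_dataset,
batch_size=train_config.autoencoder_batch_size,
shuffle=True,
num_workers=os.cpu_count(),
pin_memory=True)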
# --- Model Initialization ----------------------------------------------
model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device)
discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device)
# --- Load Checkpoints --------------------------------------------------
checkpoint_path = os.path.join(train_config.task_name, "vqvae_checkpoint.pth")
total_steps, start_epoch, _ = load_checkpoint(checkpoint_path, model, discriminator, None, None)
# --- Loss Function Initialization --------------------------------------
recon_criterion = torch.nn.MSELoss()
lpips_model = LPIPS().eval().to(device)
disc_criterion = torch.nn.MSELoss()
# --- Optimizer Initialization ------------------------------------------
optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999))
num_epochs = train_config.autoencoder_epochs
acc_steps = train_config.autoencoder_acc_steps
image_save_steps = train_config.autoencoder_img_save_steps
img_save_count = 0
# Create necessary directories
os.makedirs(os.path.join(train_config.task_name, "vqvae_autoencoder_samples"), exist_ok=True)
# --- Training Loop -----------------------------------------------------
for epoch_idx in range(start_epoch, num_epochs):
recon_losses, codebook_losses, perceptual_losses, disc_losses, gen_losses = [], [], [], [], []
for images in dataloader:
total_steps += 1
images = images.to(device)
# Forward pass
model_output = model(images)
output, z, quantize_losses = model_output
# Save generated images periodically
if total_steps % image_save_steps == 0 or total_steps == 1:
sample_size = min(8, images.shape[0])
save_output = torch.clamp(output[:sample_size], -1.0, 1.0).detach().cpu()
save_output = (save_output + 1) / 2
save_input = ((images[:sample_size] + 1) / 2).detach().cpu()
grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size)
img = torchvision.transforms.ToPILImage()(grid)
img.save(
os.path.join(
train_config.task_name,
"vqvae_autoencoder_samples",
f"current_autoencoder_sample_{img_save_count}.png",
)
)
img_save_count += 1
img.close()
# Reconstruction Loss
recon_loss = recon_criterion(output, images) / acc_steps
recon_losses.append(recon_loss.item())
# Generator Loss
codebook_loss = train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps
codebook_losses.append(codebook_loss.item())
perceptual_loss = train_config.perceptual_weight * lpips_model(output, images).mean() / acc_steps
perceptual_losses.append(perceptual_loss.item())
g_loss = recon_loss + codebook_loss + perceptual_loss
if total_steps > train_config.disc_start:
disc_fake_pred = discriminator(output)
gen_loss = train_config.disc_weight * disc_criterion(
disc_fake_pred, torch.ones_like(disc_fake_pred)
) / acc_steps
g_loss += gen_loss
gen_losses.append(gen_loss.item())
g_loss.backward()
optimizer_g.step()
optimizer_g.zero_grad()
# Discriminator Loss
if total_steps > train_config.disc_start:
disc_fake_pred = discriminator(output.detach())
disc_real_pred = discriminator(images)
disc_fake_loss = disc_criterion(
disc_fake_pred, torch.zeros_like(disc_fake_pred)
) / acc_steps
disc_real_loss = disc_criterion(
disc_real_pred, torch.ones_like(disc_real_pred)
) / acc_steps
disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2
disc_loss.backward()
optimizer_d.step()
optimizer_d.zero_grad()
disc_losses.append(disc_loss.item())
# Save checkpoint after each epoch
save_checkpoint(total_steps, epoch_idx, model, discriminator, optimizer_d, optimizer_g, recon_losses, checkpoint_path)
# Print epoch summary
print(
f"Epoch {epoch_idx + 1}/{num_epochs} | Recon Loss: {np.mean(recon_losses):.4f} | "
f"Perceptual Loss: {np.mean(perceptual_losses):.4f} | Codebook Loss: {np.mean(codebook_losses):.4f} | "
f"G Loss: {np.mean(gen_losses):.4f} | D Loss: {np.mean(disc_losses):.4f}"
)
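# NOTE: the loop above divides each loss term by acc_steps but still steps the
# optimizers on every batch, so no real gradient accumulation takes place.
# A minimal sketch of the usual accumulation pattern (generic names, assuming a
# model that returns a reconstruction; not wired into trainVAE):
def _grad_accum_sketch(model, optimizer, criterion, data_loader, acc_steps, device):
    model.train()
    optimizer.zero_grad()
    for step, images in enumerate(data_loader, start=1):
        images = images.to(device)
        loss = criterion(model(images), images) / acc_steps  # scale so summed grads average out
        loss.backward()  # gradients accumulate across the acc_steps micro-batches
        if step % acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()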