# Uploaded by devanshsrivastav ("Add files using upload-large-folder tool"),
# commit 6f02465 (verified).
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
import sys
sys.path.append('.')
from stable_diffusion.ldm.modules.diffusionmodules.model import Encoder, Decoder
from stable_diffusion.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from stable_diffusion.ldm.util import instantiate_from_config
class AutoencoderKL(pl.LightningModule):
    """KL-regularized convolutional autoencoder used as the first stage of Stable Diffusion.

    The encoder maps an image to ``2 * z_channels`` feature maps. ``quant_conv``
    projects them to ``2 * embed_dim`` "moments" that parameterize a
    ``DiagonalGaussianDistribution`` posterior (mean and variance). A latent drawn
    from that posterior (or its mode) is projected back to ``z_channels`` by
    ``post_quant_conv`` and decoded to an image.

    In Stable Diffusion the pretrained VAE is frozen and ``lossconfig`` is
    ``torch.nn.Identity``. The training/validation hooks only work with a loss
    module accepting ``(inputs, reconstructions, posterior, optimizer_idx,
    global_step, last_layer=..., split=...)`` that also exposes a
    ``discriminator`` attribute (see ``configure_optimizers``).

    ``configure_optimizers`` reads ``self.learning_rate``. It is NOT set in
    ``__init__``, so the caller must assign it before training.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,  # torch.nn.Identity when the VAE is only used for inference
                 embed_dim,  # e.g. 4
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,  # e.g. None
                 monitor=None,  # e.g. "val/rec_loss"
                 ):
        """Build the encoder, decoder, loss and the 1x1 moment/latent projections.

        Args:
            ddconfig: keyword arguments forwarded to both ``Encoder`` and
                ``Decoder``; must contain a truthy ``double_z`` and ``z_channels``.
            lossconfig: config passed to ``instantiate_from_config`` to build
                ``self.loss``.
            embed_dim: number of latent channels sampled from the posterior.
            ckpt_path: optional checkpoint whose ``"state_dict"`` is loaded
                (non-strictly) at construction time.
            ignore_keys: iterable of state-dict key prefixes to drop when loading
                ``ckpt_path``.
            image_key: key under which ``get_input`` finds the image in a batch.
            colorize_nlabels: if given (an ``int``), the channel count of a random
                3-channel projection buffer ``colorize`` used by ``to_rgb``.
            monitor: metric name stored as ``self.monitor`` when not None.
        """
        super().__init__()
        self.image_key = image_key
        # The encoder outputs 2 * z_channels maps (ddconfig["double_z"]). They are
        # not fed to the decoder directly but projected into Gaussian moments.
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # torch.nn.Identity for frozen-VAE inference; training needs a real loss module.
        self.loss = instantiate_from_config(lossconfig)
        # double_z is required so the encoder emits moments for both mean and variance.
        assert ddconfig["double_z"]
        # 1x1 convs: encoder features -> 2 * embed_dim moments, and
        # embed_dim latent -> z_channels decoder input.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        # In Stable Diffusion ckpt_path is None: the VAE weights are loaded externally.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from a checkpoint file containing a ``"state_dict"`` entry.

        Args:
            path: checkpoint path.
            ignore_keys: iterable of key prefixes. Matching entries are removed
                before loading. ``None`` is treated as empty.

        Loading uses ``strict=False`` and ignores the returned missing/unexpected
        key lists.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        prefixes = tuple(ignore_keys or ())
        # Iterate over a snapshot of the keys because entries are deleted in the loop.
        for k in list(sd.keys()):
            # str.startswith accepts a tuple and matches any prefix, so a key that
            # matches several prefixes is deleted exactly once.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode images into a ``DiagonalGaussianDistribution`` posterior over latents.

        Example shapes (256x256 RGB input, z_channels = embed_dim = 4):
        x [bs, 3, 256, 256] -> h [bs, 8, 32, 32] -> moments [bs, 8, 32, 32].
        """
        h = self.encoder(x)
        # 1x1 conv: 2 * z_channels -> 2 * embed_dim channels, spatial size unchanged.
        moments = self.quant_conv(h)
        # NOTE(review): the distribution presumably splits moments along the channel
        # axis into mean / variance halves; confirm in DiagonalGaussianDistribution.
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latents back to images.

        Example shapes: z [bs, 4, 32, 32] -> post_quant_conv [bs, 4, 32, 32]
        -> decoder output [bs, 3, 256, 256].
        """
        # 1x1 conv: embed_dim -> z_channels.
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input`` and return ``(reconstruction, posterior)``.

        With ``sample_posterior=True`` the latent is ``posterior.sample()``;
        otherwise it is ``posterior.mode()`` (the mean).
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Return ``batch[k]`` as a contiguous float tensor in [bs, C, H, W] layout.

        The permute below implies the stored layout is [bs, H, W, C]. A 3-D
        tensor (presumably [bs, H, W]) first gets a trailing channel axis of size 1.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One adversarial training step for optimizer 0 (autoencoder) or 1 (discriminator).

        Not used in Stable Diffusion, where the pretrained VAE is frozen.
        Returns the loss for the given optimizer, or None for any other index.
        """
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Log autoencoder and discriminator validation metrics and return them as one dict."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")
        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # Return the logged metrics themselves (not the bound ``self.log_dict`` method).
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        """Create two Adam optimizers (betas=(0.5, 0.9)): autoencoder params and discriminator.

        Requires ``self.learning_rate`` to have been assigned by the caller.
        """
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the decoder's output conv; passed to the loss as ``last_layer``."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Collect images for logging.

        Returns a dict with ``"inputs"``. Unless ``only_inputs`` is set, it also
        holds ``"reconstructions"`` and ``"samples"``, which are decoded from
        standard-normal latents shaped like a posterior sample. When
        reconstructing, inputs with more than three channels are mapped to RGB
        via ``to_rgb``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels, min-max scaled to [-1, 1].

        Uses a fixed random 1x1 projection (the ``colorize`` buffer). The buffer
        is created lazily from ``x``'s channel count if ``__init__`` did not
        register it. Min/max are taken over the whole tensor (whole batch).
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Clamp the range so a constant map yields -1 instead of NaN (0 / 0).
        value_range = (x.max() - x.min()).clamp(min=1e-8)
        x = 2. * (x - x.min()) / value_range - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: ``encode``, ``decode``, ``quantize`` and ``forward`` return their input.

    Every method (and the constructor) accepts and ignores extra positional and
    keyword arguments. This mirrors the encode/decode call sites used for real
    first-stage models such as ``AutoencoderKL``.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        """Create the identity stage.

        Args:
            vq_interface: if True, ``quantize`` returns the VQ-style triple
                ``(x, None, [None, None, None])`` instead of ``x`` alone.
        """
        # nn.Module.__init__ must run before any attribute is assigned on a module.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x