|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import pytorch_lightning as pl |
|
|
|
|
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution |
|
|
# Prefer the in-repo encoder/decoder; fall back to the standalone my_vae package.
try:
    from modules.models import Encoder, Decoder
except ImportError:  # was a bare `except:`, which also hid unrelated errors (typos, SystemExit)
    from my_vae.models import Encoder, Decoder
|
|
|
|
|
|
|
|
class AutoencoderKL(pl.LightningModule):
    """KL autoencoder whose decoder consumes skip features from the encoder.

    The encoder path (``encoder`` + ``quant_conv``) is frozen at construction
    time; only the decoder path (``decoder`` + ``post_quant_conv``) is trained,
    with an ``L1 + 0.1 * MSE`` reconstruction loss against the first three
    channels of the input batch.

    NOTE(review): ``ema_decay`` and ``learn_logvar`` are accepted but currently
    unused — presumably kept for config/checkpoint compatibility; confirm
    before removing.
    """

    # Fallback checkpoint used when ``load_checkpoint=True`` and no explicit
    # ``ckpt_path`` is supplied (preserves the original hard-coded behavior).
    DEFAULT_CKPT_PATH = '/home/xxing/model/ControlNet/checkpoints/main-epoch=00-step=7000.ckpt'

    def __init__(self,
                 embed_dim=4,
                 ckpt_path=None,
                 ignore_keys=None,    # fixed: was a mutable default ([]); unused downstream
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,      # currently unused
                 learn_logvar=False,  # currently unused
                 load_checkpoint=True,
                 lr=1e-4,
                 ):
        """Build encoder/decoder, optionally load pretrained weights, freeze encoder.

        Args:
            embed_dim: latent channel count for the quantization convs.
            ckpt_path: checkpoint to load ``my_vae.*`` weights from. Falls back
                to ``DEFAULT_CKPT_PATH`` when ``None`` (the original code
                ignored this argument entirely and always used the hard-coded
                path; it is now honored when provided).
            ignore_keys: unused; kept for config compatibility.
            image_key: key used to pull the image tensor out of dict batches.
            colorize_nlabels: if given, registers a random ``(3, n, 1, 1)``
                projection buffer (``colorize``) for visualizing label maps.
            monitor: metric name exposed for checkpoint monitoring.
            ema_decay, learn_logvar: unused placeholders.
            load_checkpoint: whether to load pretrained weights at all.
            lr: Adam learning rate for the trainable (decoder) parameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["ckpt_path", "ignore_keys", "colorize_nlabels"])
        self.image_key = image_key
        self.lr = lr

        self.encoder = Encoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)

        # 2 * z_channels -> 2 * embed_dim: produces the stacked (mean, logvar)
        # moments consumed by DiagonalGaussianDistribution.
        self.quant_conv = nn.Conv2d(2 * 4, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, 4, 1)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        if load_checkpoint:
            self._load_pretrained(ckpt_path if ckpt_path is not None else self.DEFAULT_CKPT_PATH)

        self.freeze_encoder()

    def _load_pretrained(self, path):
        """Load ``my_vae.*`` weights from a Lightning checkpoint at *path*."""
        state_dict = torch.load(path, map_location=torch.device("cpu"))["state_dict"]
        # Keep only the VAE entries, stripping the "my_vae." prefix so keys
        # line up with this module's parameter names.
        new_state_dict = {
            k.replace("my_vae.", ""): v
            for k, v in state_dict.items()
            if "my_vae" in k
        }
        # NOTE(review): strict loading — will raise if the filtered dict does
        # not cover every parameter of this module; confirm checkpoints match.
        self.load_state_dict(new_state_dict)
        print("Successfully load new auto-encoder")

    def encode(self, x):
        """Encode *x* into a diagonal-Gaussian posterior plus skip features.

        Returns:
            (posterior, hs): the latent distribution and the encoder's
            intermediate feature maps (consumed by ``decode``).
        """
        h, hs = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior, hs

    def decode(self, z, hs):
        """Decode latent *z*, conditioning the decoder on skip features *hs*."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z, hs)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full reconstruction pass.

        Args:
            input: image tensor to encode and decode.
            sample_posterior: sample from the posterior if True, else use its mode.

        Returns:
            (reconstruction, posterior).
        """
        posterior, hs = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z, hs)
        return dec, posterior

    @torch.no_grad()
    def _encode_nograd(self, x):
        """Encode without gradients; used to prevent updates to encoder/quant_conv."""
        posterior, hs = self.encode(x)
        # Detach explicitly so downstream decode() gradients stop here even if
        # these tensors leak out of the no_grad context.
        z = posterior.sample().detach()
        hs = [h.detach() if isinstance(h, torch.Tensor) else h for h in hs]
        return z, hs

    def _shared_step(self, batch):
        """Compute reconstruction losses shared by train and validation steps.

        Assumes the batch tensor is (B, 3+K, H, W): the first 3 channels are
        the reconstruction target, the remaining channels are a second image
        supplying the decoder's skip features — TODO confirm with the caller.

        Returns:
            (x, rec_loss, mse_loss, loss): input tensor, L1 term, MSE term,
            and the combined objective ``L1 + 0.1 * MSE``.
        """
        x = batch[self.image_key] if isinstance(batch, dict) else batch
        target = x[:, :3, ...]
        # Latent comes from the target image; skips come from the extra channels.
        z, _ = self._encode_nograd(target)
        _, hs = self._encode_nograd(x[:, 3:, ...])
        x_hat = self.decode(z, hs)
        rec_loss = F.l1_loss(x_hat, target)
        mse_loss = F.mse_loss(x_hat, target)
        loss = rec_loss + 0.1 * mse_loss
        return x, rec_loss, mse_loss, loss

    def training_step(self, batch, batch_idx):
        """One optimization step on the decoder path; logs and returns the loss."""
        x, rec_loss, mse_loss, loss = self._shared_step(batch)
        self.log_dict({
            "train/l1": rec_loss,
            "train/mse": mse_loss,
            "train/loss": loss
        }, prog_bar=True, on_step=True, on_epoch=True, batch_size=x.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation pass mirroring ``training_step``; logs epoch-level metrics."""
        x, rec_loss, mse_loss, loss = self._shared_step(batch)
        self.log_dict({
            "val/l1": rec_loss,
            "val/mse": mse_loss,
            "val/loss": loss
        }, prog_bar=True, on_epoch=True, batch_size=x.shape[0])

    def configure_optimizers(self):
        """Adam over the trainable (decoder-side) parameters only."""
        params = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
        opt = torch.optim.Adam(params, lr=self.lr, betas=(0.9, 0.999))
        return opt

    def freeze_encoder(self):
        """Freeze encoder and quant_conv (no grads); keep decoder side trainable."""
        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.quant_conv.parameters():
            p.requires_grad = False
        for p in self.decoder.parameters():
            p.requires_grad = True
        for p in self.post_quant_conv.parameters():
            p.requires_grad = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|