File size: 5,915 Bytes
4035e2e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
try:
    from modules.models import Encoder, Decoder
except ImportError:
    # Fall back to the ``my_vae`` package when ``modules.models`` cannot be
    # imported. Only ImportError is caught so that unrelated failures (and
    # KeyboardInterrupt/SystemExit) are not silently swallowed.
    from my_vae.models import Encoder, Decoder
class AutoencoderKL(pl.LightningModule):
    """KL autoencoder configured for decoder-only fine-tuning.

    The encoder and ``quant_conv`` are frozen at construction time; only the
    decoder and ``post_quant_conv`` are handed to the optimizer. During
    training/validation the input batch is split along the channel axis: the
    first three channels are the reconstruction target and are encoded to the
    latent ``z``; the remaining channels are encoded separately and only the
    second encoder output (``hs``) of that pass is forwarded to the decoder.
    """

    # Checkpoint loaded when ``load_checkpoint`` is true and no ``ckpt_path``
    # is supplied (kept for backward compatibility with existing callers).
    DEFAULT_CKPT_PATH = "/home/xxing/model/ControlNet/checkpoints/main-epoch=00-step=7000.ckpt"

    def __init__(self,
                 embed_dim=4,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 load_checkpoint=True,
                 lr=1e-4,
                 ):
        """Build the encoder/decoder, optionally load weights, freeze the encoder.

        Args:
            embed_dim: Channel count of the latent produced by ``quant_conv``
                (which outputs ``2 * embed_dim`` channels of moments).
            ckpt_path: Checkpoint file to load when ``load_checkpoint`` is
                true. ``None`` falls back to ``DEFAULT_CKPT_PATH``.
            ignore_keys: Accepted for interface compatibility; not used.
            image_key: Key used to fetch the image tensor from dict batches.
            colorize_nlabels: If given, must be an ``int``; registers a random
                ``colorize`` buffer of shape ``(3, colorize_nlabels, 1, 1)``.
            monitor: If given, stored as ``self.monitor``.
            ema_decay: Stored in hyperparameters only; not used here.
            learn_logvar: Stored in hyperparameters only; not used here.
            load_checkpoint: Whether to load weights from the checkpoint.
            lr: Learning rate for the Adam optimizer.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["ckpt_path", "ignore_keys", "colorize_nlabels"])
        self.image_key = image_key
        self.lr = lr

        self.encoder = Encoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)
        # The literal 4 matches ``z_channels`` passed to Encoder/Decoder above.
        self.quant_conv = nn.Conv2d(2*4, 2*embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, 4, 1)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        if load_checkpoint:
            self._load_pretrained(self.DEFAULT_CKPT_PATH if ckpt_path is None else ckpt_path)

        # By default, prepare for decoder-only finetuning
        self.freeze_encoder()

    def _load_pretrained(self, path):
        """Load the ``my_vae`` sub-module weights found in the checkpoint at ``path``."""
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        state_dict = torch.load(path, map_location=torch.device("cpu"))["state_dict"]
        # Keep only entries belonging to "my_vae" and strip that prefix so the
        # keys line up with this module's own parameter names.
        new_state_dict = {
            key.replace("my_vae.", ""): value
            for key, value in state_dict.items()
            if "my_vae" in key
        }
        self.load_state_dict(new_state_dict)
        print("Successfully load new auto-encoder")

    # ---------- core VAE pieces ----------
    def encode(self, x):
        """Encode ``x``; return ``(posterior, hs)``.

        ``posterior`` is a ``DiagonalGaussianDistribution`` built from the
        ``quant_conv`` output; ``hs`` is the second value returned by the
        encoder (presumably intermediate features -- confirm in ``Encoder``).
        """
        h, hs = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior, hs

    def decode(self, z, hs):
        """Decode latent ``z`` with the decoder, passing ``hs`` through to it."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z, hs)
        return dec

    def forward(self, input, sample_posterior=True):
        """Run a full encode/decode pass.

        Args:
            input: Tensor passed to ``encode``.
            sample_posterior: Sample the latent when true, otherwise use the
                posterior mode.

        Returns:
            ``(reconstruction, posterior)``.
        """
        posterior, hs = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z, hs)
        return dec, posterior

    # ---------- training for decoder only ----------
    @torch.no_grad()
    def _encode_nograd(self, x):
        """Encode without gradients; used to prevent updates to encoder/quant_conv.

        Returns:
            ``(z, hs)`` where ``z`` is a detached posterior sample and ``hs``
            has every tensor element detached.
        """
        posterior, hs = self.encode(x)
        # Detach so gradients don't flow back to encoder/quant_conv
        z = posterior.sample().detach()
        hs = [h.detach() if isinstance(h, torch.Tensor) else h for h in hs]
        return z, hs

    def _reconstruction_losses(self, batch):
        """Shared train/val computation.

        Returns:
            ``(l1, mse, total, batch_size)`` where ``total = l1 + 0.1 * mse``.

        Raises:
            ValueError: If the batch has no channels beyond the first three,
                so there is nothing to feed the second encoder pass.
        """
        # Expect batch to be a dict with the image key or a tensor directly
        x = batch[self.image_key] if isinstance(batch, dict) else batch
        if x.shape[1] <= 3:
            raise ValueError(
                f"expected more than 3 channels (target + conditioning), got {x.shape[1]}"
            )
        target = x[:, :3, ...]
        # Latent comes from the target channels; hs comes from the remaining
        # channels. Each pass discards the half it does not need.
        z, _ = self._encode_nograd(target)
        _, hs = self._encode_nograd(x[:, 3:, ...])
        x_hat = self.decode(z, hs)

        # L1 reconstruction loss plus a small MSE term.
        rec_loss = F.l1_loss(x_hat, target)
        mse_loss = F.mse_loss(x_hat, target)
        loss = rec_loss + 0.1 * mse_loss
        return rec_loss, mse_loss, loss, x.shape[0]

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the total loss."""
        rec_loss, mse_loss, loss, batch_size = self._reconstruction_losses(batch)
        self.log_dict({
            "train/l1": rec_loss,
            "train/mse": mse_loss,
            "train/loss": loss
        }, prog_bar=True, on_step=True, on_epoch=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses (nothing is returned)."""
        rec_loss, mse_loss, loss, batch_size = self._reconstruction_losses(batch)
        self.log_dict({
            "val/l1": rec_loss,
            "val/mse": mse_loss,
            "val/loss": loss
        }, prog_bar=True, on_epoch=True, batch_size=batch_size)

    def configure_optimizers(self):
        """Return an Adam optimizer over decoder + post_quant_conv parameters only."""
        params = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
        opt = torch.optim.Adam(params, lr=self.lr, betas=(0.9, 0.999))
        return opt

    def freeze_encoder(self):
        """Freeze encoder and quant_conv (no grads); keep the decoder side trainable."""
        for frozen in (self.encoder, self.quant_conv):
            for p in frozen.parameters():
                p.requires_grad = False
        # Ensure decoder & post_quant_conv are trainable
        for trainable in (self.decoder, self.post_quant_conv):
            for p in trainable.parameters():
                p.requires_grad = True
# import torch
# import torchvision.transforms as T
# # Compose transforms
# transform = T.Compose([
# T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), # illumination jitter
# T.Lambda(lambda x: x + 0.05 * torch.randn_like(x)) # Gaussian noise
# ])
# # Example: I is [C,H,W] in [0,1] or [-1,1]
# I_aug = transform(I).clamp(-1, 1) # keep in valid range
|