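"""KL-regularized autoencoder (the Stable Diffusion first-stage VAE) whose
encoder additionally returns intermediate feature maps that the decoder
consumes as skip connections. A small __main__ smoke test reconstructs one
image from a WebDataset loader and saves the input/output pair."""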
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma

# Prefer the top-level package layout; fall back to the local one.
try:
    from modules.models import Encoder, Decoder
except ImportError:
    from modi_vae.models import Encoder, Decoder

class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 embed_dim=4,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 load_checkpoint=True,
                 ):
        super().__init__()
        # Standard SD-VAE backbone; the encoder also returns intermediate
        # feature maps (hs) that the decoder consumes as skip connections.
        self.encoder = Encoder(double_z=True, z_channels=4, resolution=256,
                               in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4],
                               num_res_blocks=2, attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=4, resolution=256,
                               in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4],
                               num_res_blocks=2, attn_resolutions=[], dropout=0.0)
        self.quant_conv = torch.nn.Conv2d(2 * 4, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, 4, 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if load_checkpoint:
            # Initialize from the first-stage (VAE) weights embedded in a
            # ControlNet checkpoint, stripping the "first_stage_model." prefix.
            state_dict = torch.load('/data07/v-wenjwang/ControlNet/CIConv/models/control_sd15_ini.ckpt',
                                    map_location=torch.device("cpu"))
            new_state_dict = {}
            for s in state_dict:
                if "first_stage_model" in s:
                    new_state_dict[s.replace("first_stage_model.", "")] = state_dict[s]
            self.load_state_dict(new_state_dict, strict=False)
    def encode(self, x):
        # Returns the latent posterior plus the encoder's intermediate
        # features for the decoder's skip connections.
        h, hs = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior, hs

    def decode(self, z, hs):
        z = self.post_quant_conv(z)
        dec = self.decoder(z, hs)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior, hs = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            # Use the distribution mean for a deterministic reconstruction.
            z = posterior.mode()
        dec = self.decode(z, hs)
        return dec, posterior

if __name__ == "__main__":
    # Smoke test: reconstruct one image from a WebDataset shard and save the
    # input/output pair for visual inspection.
    from data.laion_dataset import create_webdataset
    import torchvision
    import webdataset as wds

    image_dataset = create_webdataset(
        data_dir="/data06/v-wenjwang/COCO-2017/*/*.*",
    )
    image_dataloader = wds.WebLoader(
        dataset=image_dataset,
        batch_size=1,
        num_workers=8,
        pin_memory=True,
        prefetch_factor=2,
    )
    model = AutoencoderKL().cuda()
    model.eval()
    with torch.no_grad():
        for data in image_dataloader:
            img = data["distorted"].cuda()
            rec = model(img)[0]  # forward() returns (reconstruction, posterior)
            # Inputs are in [-1, 1]; rescale to [0, 1] before saving.
            torchvision.utils.save_image(rec * 0.5 + 0.5, "distorted.png")
            torchvision.utils.save_image(data["distorted"] * 0.5 + 0.5, "original.png")
            break