File size: 3,337 Bytes
4035e2e
 
 
 
 
 
 
 
 
 
 
 
 
6a03f22
4035e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma

# Encoder/Decoder live under a different package path depending on how this
# file is invoked; try the primary location first, then the modi_vae copy.
try:
    from modules.models import Encoder, Decoder
except ImportError:  # narrowed from a bare `except:`, which also hid unrelated errors
    from modi_vae.models import Encoder, Decoder

class AutoencoderKL(pl.LightningModule):
    """KL-regularized autoencoder (Stable Diffusion first-stage VAE variant).

    ``encode`` produces a diagonal-Gaussian posterior over a 4-channel latent
    plus the encoder's intermediate feature maps ``hs``; ``decode`` consumes
    those skip features alongside the latent (a U-Net-style decoder, unlike
    the stock SD VAE).

    Interface note: several constructor arguments (``ignore_keys``,
    ``image_key``, ``ema_decay``, ``learn_logvar``) are accepted for config
    compatibility but are currently unused in this class.
    """

    # Checkpoint whose ``first_stage_model.*`` weights seed this model when
    # no explicit ``ckpt_path`` is given (previously hard-coded inline).
    DEFAULT_CKPT_PATH = '/data07/v-wenjwang/ControlNet/CIConv/models/control_sd15_ini.ckpt'

    def __init__(self,
                 embed_dim=4,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 load_checkpoint=True
                 ):
        """
        Args:
            embed_dim: channel width of the quantized latent (default 4).
            ckpt_path: optional checkpoint to initialize from; falls back to
                ``DEFAULT_CKPT_PATH`` when ``load_checkpoint`` is True.
            ignore_keys: unused; kept for config compatibility. (Was a
                mutable default ``[]`` — now ``None``.)
            colorize_nlabels: if given, registers a random projection buffer
                for visualizing ``colorize_nlabels``-channel segmentations.
            monitor: metric name for Lightning checkpointing callbacks.
            load_checkpoint: whether to load ``first_stage_model.*`` weights.
        """
        super().__init__()
        # Standard SD-VAE topology: 256px resolution, double-z 4-ch latent.
        self.encoder = Encoder(double_z=True, z_channels=4, resolution=256,
                               in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4],
                               num_res_blocks=2, attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=4, resolution=256,
                               in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4],
                               num_res_blocks=2, attn_resolutions=[], dropout=0.0)

        # 2*4 = concatenated mean+logvar channels of the double-z encoder output.
        self.quant_conv = torch.nn.Conv2d(2 * 4, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, 4, 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        if load_checkpoint:
            # Honor an explicit ckpt_path (previously accepted but ignored);
            # default preserves the original hard-coded behavior.
            path = ckpt_path if ckpt_path is not None else self.DEFAULT_CKPT_PATH
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects — only point this at trusted checkpoints.
            state_dict = torch.load(path, map_location=torch.device("cpu"))
            # Keep only the first-stage (VAE) weights, stripping their prefix;
            # strict=False tolerates the skip-connection decoder's extra params.
            vae_state = {
                key.replace("first_stage_model.", ""): value
                for key, value in state_dict.items()
                if "first_stage_model" in key
            }
            self.load_state_dict(vae_state, strict=False)

    def encode(self, x):
        """Encode image batch ``x`` into a latent posterior.

        Returns:
            (posterior, hs): a DiagonalGaussianDistribution over the latent
            and the encoder's intermediate features for decoder skips.
        """
        h, hs = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior, hs

    def decode(self, z, hs):
        """Decode latent ``z`` back to image space using skip features ``hs``."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z, hs)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full reconstruction pass.

        Args:
            input: image batch.
            sample_posterior: sample the latent if True, else use the mode.

        Returns:
            (dec, posterior): reconstruction and the latent distribution.
        """
        posterior, hs = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, hs)
        return dec, posterior

if __name__ == "__main__":
    # Smoke test: push one batch from the COCO webdataset dump through the
    # autoencoder and dump input/reconstruction side by side as PNGs.
    import torchvision
    import webdataset as wds
    from data.laion_dataset import create_webdataset

    dataset = create_webdataset(
        data_dir="/data06/v-wenjwang/COCO-2017/*/*.*",
    )

    loader = wds.WebLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=8,
        pin_memory=True,
        prefetch_factor=2,
    )

    model = AutoencoderKL().cuda()

    for batch in loader:
        distorted = batch["distorted"].cuda()
        reconstruction, _ = model(distorted)

        # Tensors are presumably in [-1, 1]; rescale to [0, 1] for saving.
        # NOTE(review): the reconstruction goes to "distorted.png" and the
        # network input to "original.png" — confirm the names are intended.
        torchvision.utils.save_image(reconstruction * 0.5 + 0.5, "distorted.png")
        torchvision.utils.save_image(batch["distorted"] * 0.5 + 0.5, "original.png")

        break