File size: 5,915 Bytes
4035e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
# Prefer the in-repo module layout; fall back to the standalone package.
# Catch only ImportError: a bare `except` would also hide unrelated errors
# raised while importing (syntax errors, missing deps inside the module, ...).
try:
    from modules.models import Encoder, Decoder
except ImportError:
    from my_vae.models import Encoder, Decoder


class AutoencoderKL(pl.LightningModule):
    """KL-regularized image autoencoder set up for decoder-only finetuning.

    The encoder and ``quant_conv`` are frozen at construction time
    (``freeze_encoder``); only the decoder and ``post_quant_conv`` are
    optimized (``configure_optimizers``).

    NOTE(review): ``training_step``/``validation_step`` slice the input as
    ``x[:, :3]`` (reconstruction target / latent source) and ``x[:, 3:]``
    (source of the skip features ``hs`` passed to the decoder), so batches
    are assumed to carry at least 6 channels — confirm against the data
    pipeline.
    """

    # Fallback checkpoint used when load_checkpoint=True and no explicit
    # ckpt_path is given (kept for backward compatibility with the previous
    # hard-coded path).
    _DEFAULT_CKPT = '/home/xxing/model/ControlNet/checkpoints/main-epoch=00-step=7000.ckpt'

    def __init__(self,
                 embed_dim=4,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 load_checkpoint=True,
                 lr=1e-4,
                 ):
        """Build the VAE and (optionally) load pretrained ``my_vae.*`` weights.

        Args:
            embed_dim: Latent channel count for the quant/post-quant convs.
            ckpt_path: Lightning checkpoint to load weights from. Only used
                when ``load_checkpoint`` is True; defaults to ``_DEFAULT_CKPT``.
            ignore_keys: Accepted for interface compatibility; currently unused.
            image_key: Key used to pull the image tensor out of dict batches.
            colorize_nlabels: If set, register a random projection buffer for
                colorizing segmentation maps.
            monitor: Metric name for checkpoint callbacks, if any.
            ema_decay: Accepted for interface compatibility; currently unused.
            learn_logvar: Accepted for interface compatibility; currently unused.
            load_checkpoint: Whether to load pretrained weights at init.
            lr: Adam learning rate for the decoder-side parameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["ckpt_path", "ignore_keys", "colorize_nlabels"])
        self.image_key = image_key
        self.lr = lr
        # ignore_keys previously defaulted to a mutable `[]`; normalize here.
        ignore_keys = [] if ignore_keys is None else ignore_keys

        self.encoder = Encoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)
        self.decoder = Decoder(double_z=True, z_channels=4, resolution=256, in_channels=3,
                               out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                               attn_resolutions=[], dropout=0.0)

        # double_z=True: the encoder emits 2 * z_channels (mean + logvar).
        self.quant_conv = nn.Conv2d(2 * 4, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, 4, 1)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        if load_checkpoint:
            # Honor the ckpt_path argument (it was previously accepted but
            # ignored in favor of a hard-coded path).
            self._load_vae_weights(ckpt_path if ckpt_path is not None else self._DEFAULT_CKPT)

        # By default, prepare for decoder-only finetuning
        self.freeze_encoder()

    def _load_vae_weights(self, path):
        """Load ``my_vae.*``-prefixed weights from a Lightning checkpoint."""
        state_dict = torch.load(path, map_location=torch.device("cpu"))["state_dict"]
        new_state_dict = {
            k.replace("my_vae.", ""): v
            for k, v in state_dict.items()
            if "my_vae" in k
        }
        self.load_state_dict(new_state_dict)
        print("Successfully load new auto-encoder")

    # ---------- core VAE pieces ----------
    def encode(self, x):
        """Encode *x* to a diagonal-Gaussian posterior plus encoder skips ``hs``."""
        h, hs = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior, hs

    def decode(self, z, hs):
        """Decode latent *z* using encoder skip features *hs*."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z, hs)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full encode/decode pass; returns (reconstruction, posterior)."""
        posterior, hs = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z, hs)
        return dec, posterior

    # ---------- training for decoder only ----------
    @torch.no_grad()
    def _encode_nograd(self, x):
        """Encode without gradients; used to prevent updates to encoder/quant_conv."""
        posterior, hs = self.encode(x)
        # Detach so gradients don't flow back to encoder/quant_conv
        z = posterior.sample().detach()
        hs = [h.detach() if isinstance(h, torch.Tensor) else h for h in hs]
        return z, hs

    def _shared_step(self, batch):
        """Common train/val computation.

        Returns (l1, mse, combined) losses between the decoded image and the
        first 3 channels of the batch.  The latent comes from channels [:3],
        the decoder skips from channels [3:], both with the encoder frozen.
        """
        # Expect batch to be a dict with `image_key` or a tensor directly
        x = batch[self.image_key] if isinstance(batch, dict) else batch
        z, _ = self._encode_nograd(x[:, :3, ...])
        _, hs = self._encode_nograd(x[:, 3:, ...])
        x_hat = self.decode(z, hs)

        target = x[:, :3, ...]
        # Simple reconstruction loss (L1), plus a small MSE term to stabilize.
        rec_loss = F.l1_loss(x_hat, target)
        mse_loss = F.mse_loss(x_hat, target)
        loss = rec_loss + 0.1 * mse_loss
        return rec_loss, mse_loss, loss, x.shape[0]

    def training_step(self, batch, batch_idx):
        rec_loss, mse_loss, loss, bs = self._shared_step(batch)
        self.log_dict({
            "train/l1": rec_loss,
            "train/mse": mse_loss,
            "train/loss": loss
        }, prog_bar=True, on_step=True, on_epoch=True, batch_size=bs)
        return loss

    def validation_step(self, batch, batch_idx):
        rec_loss, mse_loss, loss, bs = self._shared_step(batch)
        self.log_dict({
            "val/l1": rec_loss,
            "val/mse": mse_loss,
            "val/loss": loss
        }, prog_bar=True, on_epoch=True, batch_size=bs)

    def configure_optimizers(self):
        """Adam over decoder + post_quant_conv only (encoder stays frozen)."""
        params = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
        opt = torch.optim.Adam(params, lr=self.lr, betas=(0.9, 0.999))
        return opt

    def freeze_encoder(self):
        """Freeze encoder and quant_conv (no grads)."""
        for p in self.encoder.parameters():
            p.requires_grad = False
        for p in self.quant_conv.parameters():
            p.requires_grad = False
        # Ensure decoder & post_quant_conv are trainable
        for p in self.decoder.parameters():
            p.requires_grad = True
        for p in self.post_quant_conv.parameters():
            p.requires_grad = True

# import torch
# import torchvision.transforms as T

# # Compose transforms
# transform = T.Compose([
#     T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),  # illumination jitter
#     T.Lambda(lambda x: x + 0.05 * torch.randn_like(x))  # Gaussian noise
# ])

# # Example: I is [C,H,W] in [0,1] or [-1,1]
# I_aug = transform(I).clamp(-1, 1)  # keep in valid range