# Spaces:
# Runtime error
# Runtime error
from typing import List, Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(nn.Module):
    """Convolutional autoencoder whose latent is a diagonal-Gaussian distribution.

    Images are encoded by ``Encoder``. A 1x1 ``quant_conv`` maps the encoder's
    ``2 * z_channels`` output channels to ``2 * embed_dim`` moment channels,
    which are wrapped in a ``DiagonalGaussianDistribution``. A latent with
    ``embed_dim`` channels is mapped back to ``z_channels`` by a 1x1
    ``post_quant_conv`` and decoded by ``Decoder``.

    Args:
        double_z: Must be True. ``quant_conv`` expects ``2 * z_channels``
            encoder output channels, which is presumably what ``double_z``
            enables inside ``Encoder``.
        z_channels: Channel count at the encoder output / decoder input.
        resolution, in_channels, out_ch, ch, ch_mult, num_res_blocks,
        attn_resolutions, dropout: Forwarded unchanged to both ``Encoder``
            and ``Decoder``. ``ch_mult`` defaults to ``[1, 2, 4, 4]`` and
            ``attn_resolutions`` to ``[]``.
        embed_dim: Channel count of the latent consumed by ``decode``.
        ckpt_path: If given, weights are loaded from this file after
            construction via ``init_from_ckpt``.
        ignore_keys: State-dict key prefixes to drop when loading ``ckpt_path``.

    Raises:
        ValueError: If ``double_z`` is False.
    """

    def __init__(
        self,
        double_z: bool = True,
        z_channels: int = 3,
        resolution: int = 512,
        in_channels: int = 3,
        out_ch: int = 3,
        ch: int = 128,
        ch_mult: Optional[List[int]] = None,
        num_res_blocks: int = 2,
        attn_resolutions: Optional[List[int]] = None,
        dropout: float = 0.0,
        embed_dim: int = 3,
        ckpt_path: Optional[str] = None,
        ignore_keys: Optional[List[str]] = None,
    ):
        super().__init__()
        # Validate before building the (large) encoder/decoder. Raise explicitly
        # rather than using `assert`, which is stripped under `python -O`.
        if not double_z:
            raise ValueError(
                "AutoencoderKL requires double_z=True: quant_conv expects "
                f"2 * z_channels encoder output channels (got double_z={double_z!r})."
            )
        # None sentinels instead of list defaults: a list default is created
        # once and shared by every call.
        if ch_mult is None:
            ch_mult = [1, 2, 4, 4]
        if attn_resolutions is None:
            attn_resolutions = []

        ddconfig = {
            "double_z": double_z,
            "z_channels": z_channels,
            "resolution": resolution,
            "in_channels": in_channels,
            "out_ch": out_ch,
            "ch": ch,
            "ch_mult": ch_mult,
            "num_res_blocks": num_res_blocks,
            "attn_resolutions": attn_resolutions,
            "dropout": dropout,
        }
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # 1x1 convs between the encoder/decoder channel space (z_channels) and
        # the latent space (embed_dim). quant_conv emits 2 * embed_dim moments.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)
        self.embed_dim = embed_dim

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from a checkpoint file into this module (non-strict).

        Args:
            path: File readable by ``torch.load``. It may be either a dict
                holding the weights under ``"state_dict"`` or a bare state dict.
            ignore_keys: Key prefixes. Any weight whose name starts with one
                of them is dropped before loading. ``None`` drops nothing.
        """
        prefixes = tuple(ignore_keys or ())
        # NOTE(review): since torch 2.6, torch.load defaults to
        # weights_only=True. Full training checkpoints that pickle extra Python
        # objects may then fail to load and need weights_only=False (only for
        # trusted files) -- confirm against the torch version in use.
        checkpoint = torch.load(path, map_location="cpu")
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            sd = checkpoint["state_dict"]
        else:
            sd = checkpoint

        # One startswith() test per key, so a key matched by several prefixes
        # is deleted exactly once (deleting it twice raised KeyError).
        for k in list(sd.keys()):
            if k.startswith(prefixes):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]

        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")
        # strict=False silently tolerates mismatches; surface them instead.
        if missing or unexpected:
            print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

    def encode(self, x):
        """Encode ``x`` and return the latent posterior distribution.

        Returns:
            A ``DiagonalGaussianDistribution`` built from the ``quant_conv``
            moments, which have ``2 * embed_dim`` channels.
        """
        h = self.encoder(x)  # assumed (B, C, H, W) image batch -- TODO confirm
        moments = self.quant_conv(h)  # (B, 2 * embed_dim, h, w)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior  # a distribution, not a tensor

    def decode(self, z):
        """Map a latent ``z`` (``embed_dim`` channels) back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode, pick a latent, and decode.

        Args:
            input: Batch passed to ``encode``.
            sample_posterior: If True, draw the latent with
                ``posterior.sample()``; otherwise use ``posterior.mode()``.

        Returns:
            ``(dec, posterior, last_layer_weight)``, where
            ``last_layer_weight`` is ``self.decoder.conv_out.weight``.
            Presumably a training loop uses it for adaptive loss weighting --
            TODO confirm.
        """
        posterior = self.encode(input)  # Gaussian posterior
        if sample_posterior:
            z = posterior.sample()  # stochastic draw
        else:
            z = posterior.mode()
        dec = self.decode(z)
        last_layer_weight = self.decoder.conv_out.weight
        return dec, posterior, last_layer_weight
if __name__ == '__main__':
    # Script entry point:
    #   (1) sanity-check tensor shapes with a randomly initialised model;
    #   (2) load a trained checkpoint and save a figure comparing an image with
    #       a posterior-sample decode and the forward-pass reconstruction.
    import argparse

    import numpy as np
    from PIL import Image

    from src.data.components.celeba import DalleTransformerPreprocessor

    def _to_uint8_image(t):
        """Convert a (1, C, H, W) tensor, expected in [-1, 1], to an (H, W, C) uint8 array."""
        arr = ((t + 1) / 2).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        return (arr.clip(0, 1) * 255).astype(np.uint8)

    parser = argparse.ArgumentParser(
        description="Shape test and reconstruction plot for AutoencoderKL.")
    parser.add_argument(
        "--ckpt-path",
        default="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt",
        help="Checkpoint file to load the trained VAE weights from.")
    parser.add_argument(
        "--image-path",
        default="data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg",
        help="RGB image to encode and reconstruct.")
    parser.add_argument(
        "--output",
        default="vae_reconstruction.png",
        help="Where to save the comparison figure.")
    args = parser.parse_args()

    # (1) Shape test. Nothing here needs gradients, so skip building the
    # autograd graph.
    with torch.no_grad():
        model = AutoencoderKL()
        x = torch.randn(1, 3, 512, 512)
        dec, posterior, last_layer_weight = model(x)
        assert dec.shape == (1, 3, 512, 512)
        assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
        assert last_layer_weight.shape == (3, 128, 3, 3)

    # (2) Reconstruction from the pretrained model.
    model = AutoencoderKL(ckpt_path=args.ckpt_path)
    model.eval()

    with Image.open(args.image_path) as pil_image:
        image = np.array(pil_image.convert('RGB')).astype(np.uint8)
    original = image.copy()  # untransformed copy for the "Original" panel
    transform = DalleTransformerPreprocessor(size=512, phase='test')
    image = transform(image=image)['image']
    image = image.astype(np.float32) / 127.5 - 1.0  # pixel values 0..255 -> [-1, 1]
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1CHW

    with torch.no_grad():
        dec, posterior, _ = model(image)
        # A second, independent draw from the posterior, decoded on its own.
        sampled = model.decode(posterior.sample())

    panels = (
        ("Original", original),
        ("Sampled", _to_uint8_image(sampled)),
        ("Reconstructed", _to_uint8_image(dec)),
    )
    for idx, (title, img) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.imshow(img)
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(args.output)
    plt.close()