File size: 4,001 Bytes
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import torch
from vae import VAE
from config import *
from PIL import Image
from torchvision.utils import save_image
from dataloader import ImageDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as T

def sample_latents(vae, device=None, num_samples=16, latent_channels=4, nrow=4):
    """Decode random Gaussian latents with ``vae`` and save them as an image grid.

    Draws ``num_samples`` latents of shape
    ``(latent_channels, vae_latent_dim, vae_latent_dim)`` from a standard normal
    distribution, decodes them with ``vae.decoder_conv`` and writes the result to
    ``{latent_recon_images}/random_samples.png``.

    Args:
        vae: Model exposing ``decoder_conv`` and ``parameters()``.
        device: Device on which to create the latents. Defaults to the device
            of ``vae``'s parameters, so this function no longer depends on a
            module-level ``device`` name happening to exist.
        num_samples: Number of latents to draw (default 16, as before).
        latent_channels: Channel count of each latent (default 4, as before).
            NOTE(review): the encoder output elsewhere in this file is noted as
            (16, 16, 16); confirm this default matches the VAE's latent channels.
        nrow: Number of images per row in the saved grid.
    """
    if device is None:
        # Create the noise wherever the model lives to avoid a device mismatch.
        device = next(vae.parameters()).device
    os.makedirs(latent_recon_images, exist_ok=True)
    with torch.no_grad():
        z = torch.randn(num_samples, latent_channels, vae_latent_dim, vae_latent_dim).to(device)
        samples = vae.decoder_conv(z)
        save_image(samples, f"{latent_recon_images}/random_samples.png", nrow=nrow)

def reconstruct(vae, device, epoch="", num_images=16):
    """Save a grid comparing input images with their VAE reconstructions.

    Takes the first ``num_images`` ``.jpg`` files (sorted by name) from
    ``resized_img_dir``, runs them through ``vae`` in a single batch and writes
    the inputs (first row) followed by the reconstructions (second row) to
    ``{latent_recon_images}/recon{epoch}.png``. Nothing is written if no
    ``.jpg`` files are found.

    Args:
        vae: Model whose forward pass returns a 3-tuple; the first element is
            used as the reconstruction.
        device: Device to move the input batch to.
        epoch: Suffix appended to the output file name.
        num_images: Maximum number of images to reconstruct (default 16).
    """
    os.makedirs(latent_recon_images, exist_ok=True)
    # Filter to .jpg *before* truncating: slicing first could leave fewer than
    # ``num_images`` images when non-jpg files sort among the first entries.
    jpg_names = [name for name in sorted(os.listdir(resized_img_dir)) if name.endswith(".jpg")]
    image_paths = [os.path.join(resized_img_dir, name) for name in jpg_names[:num_images]]
    if not image_paths:
        return
    # One batch holding every selected image, so a single grid is written.
    recon_loader = DataLoader(ImageDataset(image_paths), batch_size=len(image_paths), shuffle=False)
    with torch.no_grad():
        for x in tqdm(recon_loader, desc="Progress", colour="#FF0000"):
            x = x.to(device)
            recon_x, _, _ = vae(x)
            # nrow = batch size keeps inputs and reconstructions aligned row-wise
            # even when fewer than ``num_images`` images were available.
            save_image(torch.cat([x, recon_x]), f"{latent_recon_images}/recon{epoch}.png", nrow=x.size(0))

def save_latents(vae, device):
    """Encode every file in ``resized_img_dir`` and save its latent as a ``.pt`` file.

    Each file is opened as an image, converted to RGB and to a float tensor via
    ``T.ToTensor()``, given a batch dimension of 1, encoded with
    ``vae.encode_img_to_latent`` and saved (leading dimension squeezed) as
    ``{latent_dir}/{stem}.pt``, where ``stem`` is the file name without its
    extension. ``latent_dir`` is created if missing.

    Args:
        vae: Model exposing ``encode_img_to_latent``.
        device: Device to move each image tensor to before encoding.
    """
    os.makedirs(latent_dir, exist_ok=True)
    # Build the transform once rather than once per file.
    transform = T.Compose([T.ToTensor()])
    with torch.no_grad():
        for file_name in tqdm(os.listdir(resized_img_dir), desc="Progress"):
            # The context manager closes the underlying file handle instead of
            # leaving one open per image until garbage collection.
            with Image.open(f"{resized_img_dir}/{file_name}") as pil_img:
                img = transform(pil_img.convert("RGB"))
            img = img.unsqueeze(0).to(device)
            z = vae.encode_img_to_latent(img)
            # splitext handles extensions of any length (e.g. ".jpeg"), unlike [:-4].
            stem = os.path.splitext(file_name)[0]
            # Presumably (C, H, W) = (16, 16, 16) per the original note — confirm.
            torch.save(z.squeeze(0), f"{latent_dir}/{stem}.pt")

def compute_mean_std_over_latents(device):
    """Compute the global mean and standard deviation over all saved latents.

    Every file in ``latent_dir`` is loaded with ``torch.load`` and reduced into
    a running sum, sum of squares and element count. The statistics are taken
    over all elements of all latents together, not per channel.

    Args:
        device: Device on which each per-file reduction is performed.

    Returns:
        ``(mean, std)`` as Python floats.

    Raises:
        ValueError: if ``latent_dir`` contains no latent elements.
    """
    # Running totals are Python floats (float64). Accumulating a whole dataset
    # in a float32 tensor and then computing E[z^2] - E[z]^2 loses precision.
    all_sum = 0.0
    all_sq_sum = 0.0
    count = 0
    with torch.no_grad():
        for file_name in tqdm(os.listdir(latent_dir), desc="Progress"):
            z = torch.load(f"{latent_dir}/{file_name}", map_location="cpu", weights_only=False)
            z = z.to(device).float()
            all_sum += z.sum().item()
            all_sq_sum += (z * z).sum().item()
            count += z.numel()
    if count == 0:
        raise ValueError(f"No latent values found in {latent_dir!r}")
    mean = all_sum / count
    # Clamp tiny negative values from rounding so the sqrt never produces NaN.
    var = max(all_sq_sum / count - mean * mean, 0.0)
    std = var ** 0.5
    # Previously observed values: mean ≈ -0.0074263, std ≈ 0.8829672, scale ≈ 1.1325449
    print("Mean over latent:", mean)
    print("STD over latent:", std)
    print("Latent Scale:", 1 / std if std > 0 else float("inf"))
    return mean, std

def scale_latents(latents_scale, device):
    """Multiply every latent in ``latent_dir`` by ``latents_scale``.

    Each file is loaded, moved to ``device``, scaled in place and saved under
    the same file name in ``latent_scaled_dir``, which is created if missing.

    Args:
        latents_scale: Multiplicative factor applied to every element.
        device: Device the latent is moved to before scaling.
    """
    os.makedirs(latent_scaled_dir, exist_ok=True)
    with torch.no_grad():
        names = os.listdir(latent_dir)
        for name in tqdm(names, desc="Progress", colour=tqdm_colors[5]):
            source_path = f"{latent_dir}/{name}"
            target_path = f"{latent_scaled_dir}/{name}"
            latent = torch.load(source_path, map_location="cpu", weights_only=False).to(device)
            latent *= latents_scale
            torch.save(latent, target_path)

def normalize_latents(mu, std, device):
    """Standardise every latent in ``latent_dir`` as ``(z - mu) / std``.

    Each file is loaded, moved to ``device``, standardised and saved under the
    same file name in ``latent_norm_dir``, which is created if missing.

    Args:
        mu: Value subtracted from every element.
        std: Value every centred element is divided by.
        device: Device the latent is moved to before normalisation.
    """
    os.makedirs(latent_norm_dir, exist_ok=True)
    with torch.no_grad():
        names = os.listdir(latent_dir)
        for name in tqdm(names, desc="Progress", colour=tqdm_colors[5]):
            source_path = f"{latent_dir}/{name}"
            target_path = f"{latent_norm_dir}/{name}"
            raw = torch.load(source_path, map_location="cpu", weights_only=False).to(device)
            standardized = (raw - mu) / std
            torch.save(standardized, target_path)

if __name__ == "__main__":
    # Prefer the GPU when one is available, otherwise run on the CPU.
    # NOTE: keep the name ``device`` — sample_latents may read it as a global.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load the trained VAE weights and switch to inference mode.
    vae = VAE().to(device)
    checkpoint_path = f"{vae_checkpoint_dir}/{vae_weight}"
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    vae.load_state_dict(ckpt["vae"])
    vae.eval()

    # Optional visual sanity check: decode random latents.
    # sample_latents(vae)

    # Step 1: encode every resized image into a latent file.
    save_latents(vae, device)
    # Step 2: global mean / std over all saved latents.
    latent_mean, latent_std = compute_mean_std_over_latents(device)
    # Step 3: rescale the latents so their global standard deviation becomes 1.
    scale_latents(1 / latent_std, device)
    # Alternative to step 3: full standardisation (subtract mean, divide by std).
    # normalize_latents(latent_mean, latent_std, device)