# Page header captured with this file -- Spaces: Sleeping; File size: 4,001 Bytes; 4aabce3.
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import torch
from vae import VAE
from config import *
from PIL import Image
from torchvision.utils import save_image
from dataloader import ImageDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as T
def sample_latents(vae, device=None):
    """Decode 16 random Gaussian latents and save them as a 4x4 image grid.

    Args:
        vae: Model exposing ``decoder_conv``; it is called on a tensor of
            shape ``(16, 4, vae_latent_dim, vae_latent_dim)``.
        device: Device the random latents are moved to. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters), so the function no longer relies on a
            module-level ``device`` variable being defined.

    Side effects:
        Creates ``latent_recon_images`` if needed and writes
        ``random_samples.png`` into it.
    """
    if device is None:
        # Keep the noise on the same device as the decoder weights.
        first_param = next(vae.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    os.makedirs(latent_recon_images, exist_ok=True)
    with torch.no_grad():
        z = torch.randn(16, 4, vae_latent_dim, vae_latent_dim).to(device)
        samples = vae.decoder_conv(z)
        save_image(samples, f"{latent_recon_images}/random_samples.png", nrow=4)
def reconstruct(vae, device, epoch=""):
    """Save a grid comparing up to 16 input images with their reconstructions.

    Args:
        vae: Model whose forward call returns a 3-tuple; the first element is
            used as the reconstruction.
        device: Device the image batch is moved to before the forward pass.
        epoch: Suffix appended to the output file name (``recon{epoch}.png``).

    Side effects:
        Creates ``latent_recon_images`` if needed and writes the comparison
        grid there. Inputs and reconstructions are concatenated along the
        batch axis and laid out 16 images per row.
    """
    os.makedirs(latent_recon_images, exist_ok=True)

    # Filter by extension *before* truncating: slicing first would let stray
    # non-JPEG entries among the first 16 names shrink the batch.
    jpg_names = [
        name for name in sorted(os.listdir(resized_img_dir)) if name.endswith(".jpg")
    ]
    image_paths = [os.path.join(resized_img_dir, name) for name in jpg_names[:16]]

    with torch.no_grad():
        recon_loader = DataLoader(ImageDataset(image_paths), batch_size=16, shuffle=False)
        # At most 16 images with batch_size=16, so this loop yields one batch.
        for x in tqdm(recon_loader, desc="Progress", colour="#FF0000"):
            x = x.to(device)
            recon_x, _, _ = vae(x)
            save_image(
                torch.cat([x, recon_x]),
                f"{latent_recon_images}/recon{epoch}.png",
                nrow=16,
            )
def save_latents(vae, device):
    """Encode every image in ``resized_img_dir`` and save its latent as ``.pt``.

    Args:
        vae: Model exposing ``encode_img_to_latent``; it is called on a
            single-image batch (the image tensor with a leading batch axis).
        device: Device the image tensor is moved to before encoding.

    Side effects:
        Creates ``latent_dir`` if needed and writes one ``<stem>.pt`` file per
        input image, where ``<stem>`` is the image file name without its
        extension.
    """
    os.makedirs(latent_dir, exist_ok=True)
    # Loop-invariant, so build the transform once instead of per image.
    to_tensor = T.Compose([T.ToTensor()])

    with torch.no_grad():
        for file_name in tqdm(os.listdir(resized_img_dir), desc="Progress"):
            # Context manager closes the file handle once the pixels are read.
            with Image.open(f"{resized_img_dir}/{file_name}") as image_file:
                img = to_tensor(image_file.convert("RGB"))
            img = img.unsqueeze(0).to(device)
            z = vae.encode_img_to_latent(img)

            # splitext handles extensions of any length (".jpeg", ".png", ...),
            # unlike slicing off a fixed four characters.
            stem = os.path.splitext(file_name)[0]
            # Original note gives the saved shape as (16, 16, 16) -- TODO confirm.
            torch.save(z.squeeze(0), f"{latent_dir}/{stem}.pt")
def compute_mean_std_over_latents(device):
    """Compute the global mean and standard deviation over all saved latents.

    Every file in ``latent_dir`` is loaded and its elements contribute to a
    single scalar mean/std (population statistics over all elements of all
    files). The mean, the std and the reciprocal of the std ("latent scale")
    are printed.

    Args:
        device: Device each latent is moved to before it is reduced.

    Returns:
        Tuple ``(mean, std)`` of Python floats.

    Raises:
        ValueError: If ``latent_dir`` contributes no elements, since the
            statistics would otherwise be NaN and silently poison any
            downstream scaling.
    """
    # Accumulate across files in Python floats (double precision): float32
    # running sums lose accuracy in E[x^2] - mean^2 once many files are added.
    total = 0.0
    total_sq = 0.0
    count = 0
    with torch.no_grad():
        for file_name in tqdm(os.listdir(latent_dir), desc="Progress"):
            # NOTE: weights_only=False unpickles arbitrary objects; only use
            # on latent files produced by this project.
            z = torch.load(
                f"{latent_dir}/{file_name}", map_location="cpu", weights_only=False
            ).to(device)
            total += z.sum().item()
            total_sq += (z ** 2).sum().item()
            count += z.numel()

    if count == 0:
        raise ValueError(f"no latent values found in {latent_dir!r}")

    mean = total / count
    # Clamp at zero: rounding can push a near-zero variance slightly negative.
    var = max(total_sq / count - mean ** 2, 0.0)
    std = var ** 0.5
    scale = 1.0 / std if std > 0 else float("inf")

    print("Mean over latent:", mean)  # -0.0074263461865484715
    print("STD over latent:", std)  # 0.8829671740531921
    print("Latent Scale:", scale)  # 1.1325448751449585
    return mean, std
def scale_latents(latents_scale, device):
    """Multiply every saved latent by ``latents_scale`` and store the result.

    Args:
        latents_scale: Factor each latent is multiplied by (in place).
        device: Device each latent is moved to before scaling.

    Side effects:
        Creates ``latent_scaled_dir`` if needed and writes one file per entry
        of ``latent_dir`` under the same file name.
    """
    os.makedirs(latent_scaled_dir, exist_ok=True)
    with torch.no_grad():
        progress = tqdm(os.listdir(latent_dir), desc="Progress", colour=tqdm_colors[5])
        for latent_name in progress:
            source_path = f"{latent_dir}/{latent_name}"
            target_path = f"{latent_scaled_dir}/{latent_name}"

            latent = torch.load(source_path, map_location="cpu", weights_only=False)
            latent = latent.to(device)
            latent *= latents_scale
            torch.save(latent, target_path)  # (16, 16, 16)
def normalize_latents(mu, std, device):
    """Standardise every saved latent as ``(z - mu) / std`` and store the result.

    Args:
        mu: Value subtracted from each latent.
        std: Value each centred latent is divided by.
        device: Device each latent is moved to before normalising.

    Side effects:
        Creates ``latent_norm_dir`` if needed and writes one file per entry of
        ``latent_dir`` under the same file name.
    """
    os.makedirs(latent_norm_dir, exist_ok=True)
    with torch.no_grad():
        progress = tqdm(os.listdir(latent_dir), desc="Progress", colour=tqdm_colors[5])
        for latent_name in progress:
            source_path = f"{latent_dir}/{latent_name}"
            target_path = f"{latent_norm_dir}/{latent_name}"

            latent = torch.load(source_path, map_location="cpu", weights_only=False)
            latent = latent.to(device)
            centred = latent - mu
            normalised = centred / std
            torch.save(normalised, target_path)  # (16, 16, 16)
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the trained VAE from its checkpoint and switch to eval mode.
    vae = VAE().to(device)
    checkpoint_path = f"{vae_checkpoint_dir}/{vae_weight}"
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    vae.load_state_dict(state["vae"])
    vae.eval()

    # Optional step, disabled in the original script:
    # sample_latents(vae)

    # Pipeline: encode images -> measure latent statistics -> rescale by 1/std.
    save_latents(vae, device)
    mu, std = compute_mean_std_over_latents(device)
    scale_latents(1 / std, device)

    # Alternative to scaling, disabled in the original script:
    # normalize_latents(mu, std, device)