import os
import sys

# Make sibling project modules (vae, config, dataloader) importable when run as a script.
sys.path.insert(0, os.path.dirname(__file__))

import torch
from vae import VAE
from config import *
from PIL import Image
from torchvision.utils import save_image
from dataloader import ImageDataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as T


def sample_latents(vae, device=None):
    """Decode 16 random Gaussian latents and save them as a 4x4 image grid.

    Args:
        vae: trained VAE; only its ``decoder_conv`` module is used.
        device: device on which to draw the noise. Defaults to the device of
            ``vae``'s parameters (previously this relied on a module-level
            ``device`` name that only exists in some execution contexts).

    Writes ``{latent_recon_images}/random_samples.png``.
    """
    if device is None:
        device = next(vae.parameters()).device
    os.makedirs(latent_recon_images, exist_ok=True)
    with torch.no_grad():
        # NOTE(review): 4 latent channels are hard-coded here, while the shape
        # comments in save_latents mention (16, 16, 16) — confirm this matches
        # what decoder_conv expects.
        z = torch.randn(16, 4, vae_latent_dim, vae_latent_dim, device=device)
        samples = vae.decoder_conv(z)
        save_image(samples, f"{latent_recon_images}/random_samples.png", nrow=4)


def reconstruct(vae, device, epoch=""):
    """Save a two-row grid of up to 16 images and their VAE reconstructions.

    The first 16 ``.jpg`` files (sorted by name) in ``resized_img_dir`` are run
    through ``vae``. Originals form the first row and reconstructions the second
    row of ``{latent_recon_images}/recon{epoch}.png``.

    Args:
        vae: model whose forward pass returns a 3-tuple whose first element is
            the reconstruction.
        device: device to run the forward pass on.
        epoch: suffix appended to the output file name.
    """
    os.makedirs(latent_recon_images, exist_ok=True)
    # Filter before slicing so non-.jpg entries don't reduce the sample count.
    jpg_names = [
        name for name in sorted(os.listdir(resized_img_dir)) if name.endswith(".jpg")
    ][:16]
    image_paths = [os.path.join(resized_img_dir, name) for name in jpg_names]
    recon_loader = DataLoader(ImageDataset(image_paths), batch_size=16, shuffle=False)
    with torch.no_grad():
        for x in tqdm(recon_loader, desc="Progress", colour="#FF0000"):
            x = x.to(device)
            recon_x, _, _ = vae(x)
            # One row of originals above one row of reconstructions, even when
            # fewer than 16 images are available.
            save_image(
                torch.cat([x, recon_x]),
                f"{latent_recon_images}/recon{epoch}.png",
                nrow=x.size(0),
            )


def save_latents(vae, device):
    """Encode every file in ``resized_img_dir`` and save its latent as ``.pt``.

    Each image is loaded as RGB, converted with ``T.ToTensor()``, encoded with
    ``vae.encode_img_to_latent`` and saved (leading batch dim removed) to
    ``{latent_dir}/{file stem}.pt``.
    """
    os.makedirs(latent_dir, exist_ok=True)
    transform = T.Compose([T.ToTensor()])  # built once, not per file
    with torch.no_grad():
        for file_name in tqdm(os.listdir(resized_img_dir), desc="Progress"):
            # Context manager closes the underlying file handle promptly.
            with Image.open(os.path.join(resized_img_dir, file_name)) as im:
                img = transform(im.convert("RGB"))
            img = img.unsqueeze(0).to(device)
            z = vae.encode_img_to_latent(img)
            # splitext handles any extension length (e.g. ".jpeg"), unlike [:-4].
            stem = os.path.splitext(file_name)[0]
            torch.save(z.squeeze(0), f"{latent_dir}/{stem}.pt")  # (16, 16, 16) per original note


def compute_mean_std_over_latents(device):
    """Return the global ``(mean, std)`` over every element of every latent file.

    A single streaming pass accumulates the sum and the sum of squares, so the
    latents are never all in memory at once. Accumulation is done in float64
    to limit round-off over many files.

    Args:
        device: device on which the accumulation is performed.

    Returns:
        Tuple of Python floats ``(mean, std)``.

    Raises:
        ValueError: if ``latent_dir`` contains no latent elements.
    """
    all_sum = torch.zeros((), dtype=torch.float64, device=device)
    all_sq_sum = torch.zeros((), dtype=torch.float64, device=device)
    count = 0
    with torch.no_grad():
        for file_name in tqdm(os.listdir(latent_dir), desc="Progress"):
            # weights_only=False unpickles arbitrary objects: only use on trusted files.
            z = torch.load(
                f"{latent_dir}/{file_name}", map_location="cpu", weights_only=False
            ).to(device, dtype=torch.float64)
            all_sum += z.sum()
            all_sq_sum += (z ** 2).sum()
            count += z.numel()
    if count == 0:
        raise ValueError(f"No latent values found in {latent_dir!r}")
    mean = all_sum / count
    # E[z^2] - E[z]^2 can dip slightly below zero from rounding; clamp before sqrt.
    var = (all_sq_sum / count - mean ** 2).clamp_min(0.0)
    std = torch.sqrt(var)
    print("Mean over latent:", mean.item())  # previously observed: -0.0074263461865484715
    print("STD over latent:", std.item())  # previously observed: 0.8829671740531921
    print("Latent Scale:", (1 / std).item())  # previously observed: 1.1325448751449585
    return mean.item(), std.item()


def _map_latents(out_dir, fn, device):
    """Load each latent in ``latent_dir``, apply ``fn`` and save it under the same name in ``out_dir``."""
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for file_name in tqdm(os.listdir(latent_dir), desc="Progress", colour=tqdm_colors[5]):
            z = torch.load(
                f"{latent_dir}/{file_name}", map_location="cpu", weights_only=False
            ).to(device)
            torch.save(fn(z), f"{out_dir}/{file_name}")  # (16, 16, 16) per original note


def scale_latents(latents_scale, device):
    """Multiply every latent in ``latent_dir`` by ``latents_scale`` and save to ``latent_scaled_dir``."""
    _map_latents(latent_scaled_dir, lambda z: z * latents_scale, device)


def normalize_latents(mu, std, device):
    """Standardize every latent in ``latent_dir`` as ``(z - mu) / std`` and save to ``latent_norm_dir``."""
    _map_latents(latent_norm_dir, lambda z: (z - mu) / std, device)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = VAE().to(device)
    path = f"{vae_checkpoint_dir}/{vae_weight}"
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    vae.load_state_dict(checkpoint["vae"])
    vae.eval()
    # Optional: sample_latents(vae)
    save_latents(vae, device)
    mu, std = compute_mean_std_over_latents(device)
    scale_latents(1 / std, device)
    # Optional alternative to scaling: normalize_latents(mu, std, device)