Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import torch | |
| from vae import VAE | |
| from config import * | |
| from PIL import Image | |
| from torchvision.utils import save_image | |
| from dataloader import ImageDataset | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
| import torchvision.transforms as T | |
def sample_latents(vae, device=None, num_samples=16, latent_channels=4):
    """Decode random Gaussian latents with the VAE decoder and save them as an image grid.

    Draws ``num_samples`` latents of shape
    ``(latent_channels, vae_latent_dim, vae_latent_dim)`` from a standard normal,
    decodes them with ``vae.decoder_conv`` and writes the result to
    ``{latent_recon_images}/random_samples.png`` with 4 images per row.

    Args:
        vae: Trained VAE exposing a ``decoder_conv`` module.
        device: Device to sample on. Defaults to the device of ``vae``'s
            parameters. Previously this function read a module-level
            ``device`` that only exists when this file runs as a script.
        num_samples: Number of latents to draw (default 16, as before).
        latent_channels: Channel count of each latent (default 4, as before).
    """
    os.makedirs(latent_recon_images, exist_ok=True)
    if device is None:
        # Use wherever the model lives rather than relying on a script-only global.
        device = next(vae.parameters()).device
    with torch.no_grad():
        z = torch.randn(num_samples, latent_channels, vae_latent_dim, vae_latent_dim).to(device)
        samples = vae.decoder_conv(z)
        save_image(samples, f"{latent_recon_images}/random_samples.png", nrow=4)
def reconstruct(vae, device, epoch=""):
    """Save a grid of input images alongside their VAE reconstructions.

    Takes the first 16 ``.jpg`` files (in sorted filename order) from
    ``resized_img_dir``, runs them through ``vae`` and writes a grid to
    ``{latent_recon_images}/recon{epoch}.png``. The first row holds the
    originals and the second row holds the reconstructions.

    Args:
        vae: Model whose forward pass returns a 3-tuple whose first element
            is the reconstruction.
        device: Device the batch is moved to before the forward pass.
        epoch: Suffix appended to the output filename (e.g. an epoch number).
    """
    os.makedirs(latent_recon_images, exist_ok=True)
    # Filter *before* truncating. Slicing first meant any non-.jpg entries among
    # the first 16 names silently reduced the number of images shown.
    jpg_names = [name for name in sorted(os.listdir(resized_img_dir)) if name.endswith(".jpg")][:16]
    if not jpg_names:
        return
    image_paths = [os.path.join(resized_img_dir, name) for name in jpg_names]
    recon_loader = DataLoader(ImageDataset(image_paths), batch_size=16, shuffle=False)
    with torch.no_grad():
        for x in tqdm(recon_loader, desc="Progress", colour="#FF0000"):
            x = x.to(device)
            recon_x, _, _ = vae(x)
            # nrow == batch size keeps originals and reconstructions on separate
            # rows even when fewer than 16 images are available.
            save_image(torch.cat([x, recon_x]), f"{latent_recon_images}/recon{epoch}.png", nrow=x.size(0))
def save_latents(vae, device):
    """Encode every file in ``resized_img_dir`` and save each latent as a ``.pt`` file.

    Each image is opened with PIL, converted to RGB and turned into a tensor
    with ``T.ToTensor()``. A batch dimension is added, the tensor is encoded
    with ``vae.encode_img_to_latent``, and the result is saved with the batch
    dimension removed to ``{latent_dir}/{stem}.pt``. ``stem`` is the filename
    without its extension.

    Args:
        vae: Model exposing ``encode_img_to_latent``.
        device: Device the image tensor is moved to before encoding.
    """
    os.makedirs(latent_dir, exist_ok=True)
    # The transform is stateless, so build it once instead of once per image.
    transform = T.Compose([T.ToTensor()])
    with torch.no_grad():
        for file_name in tqdm(os.listdir(resized_img_dir), desc="Progress"):
            # Context manager closes the file handle that Image.open keeps open.
            with Image.open(os.path.join(resized_img_dir, file_name)) as pil_img:
                img = transform(pil_img.convert("RGB"))
            img = img.unsqueeze(0).to(device)
            z = vae.encode_img_to_latent(img)
            # splitext works for any extension length (".jpeg", none at all),
            # unlike the previous file_name[:-4].
            stem = os.path.splitext(file_name)[0]
            torch.save(z.squeeze(0), f"{latent_dir}/{stem}.pt")
def compute_mean_std_over_latents(device):
    """Compute one global mean and standard deviation over all saved latents.

    Every file in ``latent_dir`` is loaded with ``torch.load``. All of its
    elements are pooled into a single scalar mean and std (not per channel).

    Per-file sums are computed on ``device``. The running totals across files
    are kept as Python floats (double precision), because
    ``E[z^2] - E[z]^2`` loses precision badly when accumulated in float32
    over a large dataset.

    Args:
        device: Device each latent is moved to for the per-file reductions.

    Returns:
        ``(mean, std)`` as Python floats.

    Raises:
        ValueError: If the files in ``latent_dir`` contain no elements.
    """
    total = 0.0
    total_sq = 0.0
    count = 0
    with torch.no_grad():
        for file_name in tqdm(os.listdir(latent_dir), desc="Progress"):
            z = torch.load(f"{latent_dir}/{file_name}", map_location="cpu", weights_only=False).to(device)
            total += z.sum().item()
            total_sq += (z ** 2).sum().item()
            count += z.numel()
    if count == 0:
        raise ValueError(f"No latent values found in {latent_dir!r}")
    mean = total / count
    # Rounding can push the variance slightly below zero; clamp before sqrt.
    var = max(total_sq / count - mean ** 2, 0.0)
    std = var ** 0.5
    latent_scale = 1.0 / std if std > 0 else float("inf")
    # Values previously observed on this dataset:
    # mean ~ -0.0074263, std ~ 0.8829672, scale ~ 1.1325449
    print("Mean over latent:", mean)
    print("STD over latent:", std)
    print("Latent Scale:", latent_scale)
    return mean, std
def scale_latents(latents_scale, device):
    """Multiply every latent in ``latent_dir`` by ``latents_scale`` and save a copy.

    Each file is loaded with ``torch.load`` and moved to ``device``. It is
    scaled by ``latents_scale`` and written under the same filename into
    ``latent_scaled_dir``. The source files are not modified.

    Args:
        latents_scale: Multiplier applied to every element.
        device: Device the latent is moved to before scaling; the saved tensor
            stays on this device.
    """
    os.makedirs(latent_scaled_dir, exist_ok=True)
    names = os.listdir(latent_dir)
    with torch.no_grad():
        for name in tqdm(names, desc="Progress", colour=tqdm_colors[5]):
            latent = torch.load(f"{latent_dir}/{name}", map_location="cpu", weights_only=False).to(device)
            latent.mul_(latents_scale)
            torch.save(latent, f"{latent_scaled_dir}/{name}")
def normalize_latents(mu, std, device):
    """Standardize every latent in ``latent_dir`` as ``(z - mu) / std`` and save a copy.

    Each file is loaded with ``torch.load``, moved to ``device`` and
    standardized. It is then written under the same filename into
    ``latent_norm_dir``. The source files are not modified.

    Args:
        mu: Value subtracted from every element.
        std: Divisor applied after the subtraction.
        device: Device the latent is moved to before standardizing.
    """
    os.makedirs(latent_norm_dir, exist_ok=True)
    names = os.listdir(latent_dir)
    with torch.no_grad():
        for name in tqdm(names, desc="Progress", colour=tqdm_colors[5]):
            latent = torch.load(f"{latent_dir}/{name}", map_location="cpu", weights_only=False)
            standardized = latent.to(device).sub(mu).div(std)
            torch.save(standardized, f"{latent_norm_dir}/{name}")
if __name__ == "__main__":
    # Prefer the GPU when one is present. This stays a module-level name.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the VAE and restore the weights stored under the "vae" key.
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt_file = f"{vae_checkpoint_dir}/{vae_weight}"
    vae = VAE().to(device)
    ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
    vae.load_state_dict(ckpt["vae"])
    vae.eval()

    # Optional visual sanity check: decode random latents.
    # sample_latents(vae)

    # Encode all images, measure the pooled latent statistics, then rescale
    # the latents by 1 / std.
    save_latents(vae, device)
    latent_mean, latent_std = compute_mean_std_over_latents(device)
    scale_latents(1 / latent_std, device)

    # Alternative to rescaling: full standardization.
    # normalize_latents(latent_mean, latent_std, device)