Spaces:
Paused
Paused
| import torch | |
| from torch import optim | |
| from torch.nn import functional as FF | |
| from torchvision import transforms | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import dataclasses | |
| from .lpips import util | |
def noise_regularize(noises):
    """Penalise spatial autocorrelation of the noise maps at several scales.

    For each map, add the squared mean of its product with a copy rolled by
    one pixel along dim 3, then the same along dim 2. The map is then
    2x2-average-pooled and the process repeats until its side is <= 8.

    Args:
        noises: iterable of 4-D tensors. The reshape below presumes square
            maps with a single channel (TODO confirm against the generator).

    Returns:
        The accumulated penalty as a scalar tensor (the int ``0`` when
        *noises* is empty).
    """
    total = 0
    for level in noises:
        side = level.shape[2]
        while True:
            horiz = (level * torch.roll(level, shifts=1, dims=3)).mean().pow(2)
            vert = (level * torch.roll(level, shifts=1, dims=2)).mean().pow(2)
            # Same left-to-right addition order as before, so floats match exactly.
            total = total + horiz + vert
            if side <= 8:
                break
            half = side // 2
            # 2x2 average pooling via reshape + mean over the pooled axes.
            level = level.reshape([-1, 1, half, 2, half, 2]).mean([3, 5])
            side = half
    return total
def noise_normalize_(noises):
    """Standardise every tensor in *noises* in place to zero mean, unit std.

    The update is written through ``.data`` so it is not tracked by autograd.
    The standard deviation is torch's default (unbiased) ``std()``.
    Returns ``None``.
    """
    for buf in noises:
        mu, sigma = buf.mean(), buf.std()
        buf.data.sub_(mu).div_(sigma)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at normalised progress *t* (0 = start, 1 = end).

    The rate ramps linearly from 0 up to *initial_lr* over the first *rampup*
    fraction of the run. It then follows a half-cosine down to 0 over the
    final *rampdown* fraction.

    Args:
        t: progress through the run, expected in [0, 1].
        initial_lr: peak learning rate.
        rampdown: fraction of the run used for the cosine ramp-down.
        rampup: fraction of the run used for the linear warm-up.
    """
    # Local import: `math` is not imported at module level, and without it
    # this function raised NameError on its first call.
    import math

    # Cosine ramp-down: 1.0 until the last `rampdown` of the run, then to 0.
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    # Linear warm-up over the first `rampup` of the run.
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp
def latent_noise(latent, strength):
    """Return *latent* plus i.i.d. standard-normal noise scaled by *strength*.

    The input tensor is not modified; a new tensor is returned.
    """
    return latent + strength * torch.randn_like(latent)
def make_image(tensor):
    """Convert a batch of image tensors to a uint8 NumPy array for saving.

    Values are clamped to [-1, 1] and mapped linearly to [0, 255]; the uint8
    cast truncates fractional parts. The axes are permuted from
    (N, C, H, W) to (N, H, W, C), and the result is moved to the CPU.

    The input tensor is left untouched. The previous ``clamp_`` on a
    ``detach()``-ed view (which shares storage) silently clamped the
    caller's tensor in place.
    """
    return (
        tensor.detach()
        .clamp(min=-1, max=1)
        .add(1)
        .div(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
@dataclasses.dataclass
class InverseConfig:
    """Hyper-parameters for the latent/noise optimisation in ``inverse_image``.

    Because this is a dataclass, individual fields can be overridden at
    construction, e.g. ``InverseConfig(step=500, w_plus=True)``. The
    defaults are also readable as class attributes.
    """

    # Fraction of the run used for LR warm-up (same value as get_lr's `rampup` default).
    lr_warmup: float = 0.05
    # Fraction of the run used for LR cosine ramp-down (same value as get_lr's `rampdown` default).
    lr_decay: float = 0.25
    # Peak learning rate handed to Adam.
    lr: float = 0.1
    # Scale of the Gaussian jitter added to the latent, relative to the latent std.
    noise: float = 0.05
    # Fraction of the run after which the latent jitter has decayed to zero.
    noise_decay: float = 0.75
    # Number of optimisation iterations.
    step: int = 1000
    # Weight of the noise-map autocorrelation penalty.
    noise_regularize: float = 1e5
    # Weight of the pixel MSE term (0 disables it).
    mse: float = 0
    # Optimise one latent per generator layer (W+) instead of a single shared latent.
    # A stray trailing comma used to make this the truthy tuple (False,).
    w_plus: bool = False
def inverse_image(
    g_ema,
    image,
    image_size=256,
    config=None,
    device="cuda",
    save_path="project.png",
):
    """Project a PIL image into the latent and noise inputs of generator ``g_ema``.

    A latent code (initialised at the mean mapped latent) and the per-layer
    noise maps from ``g_ema.make_noise()`` are optimised with Adam. The loss
    combines an LPIPS perceptual loss against *image*, a noise
    autocorrelation penalty and an optional pixel MSE term.

    Args:
        g_ema: generator object. This function uses its ``style``,
            ``make_noise``, ``prepare`` and ``generate`` methods, plus
            ``n_latent`` in W+ mode.
        image: PIL image to invert.
        image_size: the target is resized and center-cropped to
            ``min(image_size, 256)`` pixels.
        config: hyper-parameters, ``InverseConfig``-like. A fresh
            ``InverseConfig()`` is used when omitted.
        device: torch device string for all tensors created here
            (default ``"cuda"``).
        save_path: file to write the final projection to, or ``None`` to skip
            writing (default ``"project.png"``, as before).

    Returns:
        dict with keys:
            ``"latent"``: output of ``g_ema.prepare``.
            ``"noise"``: list of the first sample of each noise map.
            ``"F"``: second value returned by ``g_ema.generate``.
            ``"sample"``: the generated image tensor.
    """
    args = InverseConfig() if config is None else config
    n_mean_latent = 10000

    # Target image: resize/crop, then map pixel values from [0, 1] to [-1, 1].
    resize = min(image_size, 256)
    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    imgs = transform(image).unsqueeze(0).to(device)  # batch of one

    # Estimate the mean mapped latent and its spread from random z samples.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=str(device).startswith("cuda")
    )

    # Optimisation variables: fresh N(0, 1) noise maps and a latent starting at the mean.
    noises = [
        noise.repeat(imgs.shape[0], 1, 1, 1).normal_() for noise in g_ema.make_noise()
    ]
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
    if args.w_plus:
        # W+: one latent per generator layer instead of a single shared latent.
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
    latent_in.requires_grad = True
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
    # Honour the config's LR schedule fractions. They were previously ignored;
    # the fallbacks equal get_lr's own defaults.
    rampdown = getattr(args, "lr_decay", 0.25)
    rampup = getattr(args, "lr_warmup", 0.05)

    pbar = tqdm(range(args.step))
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=rampdown, rampup=rampup)
        optimizer.param_groups[0]["lr"] = lr
        # Latent jitter that decays quadratically to zero at t = noise_decay.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
        img_gen, _ = g_ema.generate(latent, noise)

        batch, channel, height, width = img_gen.shape
        if height > 256:
            # Average over factor x factor blocks to bring the output down to ~256 px.
            factor = height // 256
            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            ).mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = FF.mse_loss(img_gen, imgs)
        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        noise_normalize_(noises)

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Re-render from the latent after the final optimiser step. The old code
    # read a snapshot list filled every 100 steps, which was empty (IndexError)
    # when step < 100 and stale when step was not a multiple of 100.
    final_latent = latent_in.detach().clone()
    latent, noise = g_ema.prepare([final_latent], input_is_latent=True, noise=noises)
    img_gen, feats = g_ema.generate(latent, noise)

    result = {
        "latent": latent,
        "noise": [noise_map[0:1] for noise_map in noises],
        "F": feats,
        "sample": img_gen,
    }
    if save_path is not None:
        Image.fromarray(make_image(img_gen)[0]).save(save_path)
    return result