Spaces:
Sleeping
Sleeping
Upload projector.py
Browse files- projector.py +213 -0
projector.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Project given image to the latent space of pretrained network pickle."""
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import os
|
| 13 |
+
from time import perf_counter
|
| 14 |
+
|
| 15 |
+
import click
|
| 16 |
+
import imageio
|
| 17 |
+
import numpy as np
|
| 18 |
+
import PIL.Image
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
import dnnlib
|
| 23 |
+
import legacy
|
| 24 |
+
|
| 25 |
+
def project(
|
| 26 |
+
G,
|
| 27 |
+
target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
|
| 28 |
+
*,
|
| 29 |
+
num_steps = 1000,
|
| 30 |
+
w_avg_samples = 10000,
|
| 31 |
+
initial_learning_rate = 0.1,
|
| 32 |
+
initial_noise_factor = 0.05,
|
| 33 |
+
lr_rampdown_length = 0.25,
|
| 34 |
+
lr_rampup_length = 0.05,
|
| 35 |
+
noise_ramp_length = 0.75,
|
| 36 |
+
regularize_noise_weight = 1e5,
|
| 37 |
+
verbose = False,
|
| 38 |
+
device: torch.device
|
| 39 |
+
):
|
| 40 |
+
assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
|
| 41 |
+
|
| 42 |
+
def logprint(*args):
|
| 43 |
+
if verbose:
|
| 44 |
+
print(*args)
|
| 45 |
+
|
| 46 |
+
G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
|
| 47 |
+
|
| 48 |
+
# Compute w stats.
|
| 49 |
+
logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
|
| 50 |
+
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
|
| 51 |
+
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
|
| 52 |
+
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
|
| 53 |
+
w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
|
| 54 |
+
w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
|
| 55 |
+
|
| 56 |
+
# Setup noise inputs.
|
| 57 |
+
noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
|
| 58 |
+
|
| 59 |
+
# Load VGG16 feature detector.
|
| 60 |
+
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
|
| 61 |
+
with dnnlib.util.open_url(url) as f:
|
| 62 |
+
vgg16 = torch.jit.load(f).eval().to(device)
|
| 63 |
+
|
| 64 |
+
# Features for target image.
|
| 65 |
+
target_images = target.unsqueeze(0).to(device).to(torch.float32)
|
| 66 |
+
if target_images.shape[2] > 256:
|
| 67 |
+
target_images = F.interpolate(target_images, size=(256, 256), mode='area')
|
| 68 |
+
target_features = vgg16(target_images, resize_images=False, return_lpips=True)
|
| 69 |
+
|
| 70 |
+
w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
|
| 71 |
+
w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
|
| 72 |
+
optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
|
| 73 |
+
|
| 74 |
+
# Init noise.
|
| 75 |
+
for buf in noise_bufs.values():
|
| 76 |
+
buf[:] = torch.randn_like(buf)
|
| 77 |
+
buf.requires_grad = True
|
| 78 |
+
|
| 79 |
+
for step in range(num_steps):
|
| 80 |
+
# Learning rate schedule.
|
| 81 |
+
t = step / num_steps
|
| 82 |
+
w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
|
| 83 |
+
lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
|
| 84 |
+
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
|
| 85 |
+
lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
|
| 86 |
+
lr = initial_learning_rate * lr_ramp
|
| 87 |
+
for param_group in optimizer.param_groups:
|
| 88 |
+
param_group['lr'] = lr
|
| 89 |
+
|
| 90 |
+
# Synth images from opt_w.
|
| 91 |
+
w_noise = torch.randn_like(w_opt) * w_noise_scale
|
| 92 |
+
ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
|
| 93 |
+
synth_images = G.synthesis(ws, noise_mode='const')
|
| 94 |
+
|
| 95 |
+
# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
|
| 96 |
+
synth_images = (synth_images + 1) * (255/2)
|
| 97 |
+
if synth_images.shape[2] > 256:
|
| 98 |
+
synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
|
| 99 |
+
|
| 100 |
+
# Features for synth images.
|
| 101 |
+
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
|
| 102 |
+
dist = (target_features - synth_features).square().sum()
|
| 103 |
+
|
| 104 |
+
# Noise regularization.
|
| 105 |
+
reg_loss = 0.0
|
| 106 |
+
for v in noise_bufs.values():
|
| 107 |
+
noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
|
| 108 |
+
while True:
|
| 109 |
+
reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
|
| 110 |
+
reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
|
| 111 |
+
if noise.shape[2] <= 8:
|
| 112 |
+
break
|
| 113 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
| 114 |
+
loss = dist + reg_loss * regularize_noise_weight
|
| 115 |
+
|
| 116 |
+
# Step
|
| 117 |
+
optimizer.zero_grad(set_to_none=True)
|
| 118 |
+
loss.backward()
|
| 119 |
+
optimizer.step()
|
| 120 |
+
logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
|
| 121 |
+
|
| 122 |
+
# Save projected W for each optimization step.
|
| 123 |
+
w_out[step] = w_opt.detach()[0]
|
| 124 |
+
|
| 125 |
+
# Normalize noise.
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
for buf in noise_bufs.values():
|
| 128 |
+
buf -= buf.mean()
|
| 129 |
+
buf *= buf.square().mean().rsqrt()
|
| 130 |
+
|
| 131 |
+
return w_out.repeat([1, G.mapping.num_ws, 1])
|
| 132 |
+
|
| 133 |
+
#----------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
@click.command()
|
| 136 |
+
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
|
| 137 |
+
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
|
| 138 |
+
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
|
| 139 |
+
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
|
| 140 |
+
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
|
| 141 |
+
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
|
| 142 |
+
def run_projection(
|
| 143 |
+
network_pkl: str,
|
| 144 |
+
target_fname: str,
|
| 145 |
+
outdir: str,
|
| 146 |
+
save_video: bool,
|
| 147 |
+
seed: int,
|
| 148 |
+
num_steps: int
|
| 149 |
+
):
|
| 150 |
+
"""Project given image to the latent space of pretrained network pickle.
|
| 151 |
+
|
| 152 |
+
Examples:
|
| 153 |
+
|
| 154 |
+
\b
|
| 155 |
+
python projector.py --outdir=out --target=~/mytargetimg.png \\
|
| 156 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
|
| 157 |
+
"""
|
| 158 |
+
np.random.seed(seed)
|
| 159 |
+
torch.manual_seed(seed)
|
| 160 |
+
|
| 161 |
+
# Load networks.
|
| 162 |
+
print('Loading networks from "%s"...' % network_pkl)
|
| 163 |
+
device = torch.device('cuda')
|
| 164 |
+
with dnnlib.util.open_url(network_pkl) as fp:
|
| 165 |
+
G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
|
| 166 |
+
|
| 167 |
+
# Load target image.
|
| 168 |
+
target_pil = PIL.Image.open(target_fname).convert('RGB')
|
| 169 |
+
w, h = target_pil.size
|
| 170 |
+
s = min(w, h)
|
| 171 |
+
target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
|
| 172 |
+
target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
|
| 173 |
+
target_uint8 = np.array(target_pil, dtype=np.uint8)
|
| 174 |
+
|
| 175 |
+
# Optimize projection.
|
| 176 |
+
start_time = perf_counter()
|
| 177 |
+
projected_w_steps = project(
|
| 178 |
+
G,
|
| 179 |
+
target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
|
| 180 |
+
num_steps=num_steps,
|
| 181 |
+
device=device,
|
| 182 |
+
verbose=True
|
| 183 |
+
)
|
| 184 |
+
print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
|
| 185 |
+
|
| 186 |
+
# Render debug output: optional video and projected image and W vector.
|
| 187 |
+
os.makedirs(outdir, exist_ok=True)
|
| 188 |
+
if save_video:
|
| 189 |
+
print("Skipping video saving as per configuration.")
|
| 190 |
+
# video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
|
| 191 |
+
# print (f'Saving optimization progress video "{outdir}/proj.mp4"')
|
| 192 |
+
# for projected_w in projected_w_steps:
|
| 193 |
+
# synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
|
| 194 |
+
# synth_image = (synth_image + 1) * (255/2)
|
| 195 |
+
# synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
|
| 196 |
+
# video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
|
| 197 |
+
# video.close()
|
| 198 |
+
|
| 199 |
+
# Save final projected frame and W vector.
|
| 200 |
+
target_pil.save(f'{outdir}/target.png')
|
| 201 |
+
projected_w = projected_w_steps[-1]
|
| 202 |
+
synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
|
| 203 |
+
synth_image = (synth_image + 1) * (255/2)
|
| 204 |
+
synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
|
| 205 |
+
PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
|
| 206 |
+
np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
|
| 207 |
+
|
| 208 |
+
#----------------------------------------------------------------------------
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
|
| 211 |
+
run_projection() # pylint: disable=no-value-for-parameter
|
| 212 |
+
|
| 213 |
+
#----------------------------------------------------------------------------
|