uhdessai committed on
Commit
4d47529
·
verified ·
1 Parent(s): fcea2d8

Upload projector.py

Browse files
Files changed (1) hide show
  1. projector.py +213 -0
projector.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Project given image to the latent space of pretrained network pickle."""
10
+
11
+ import copy
12
+ import os
13
+ from time import perf_counter
14
+
15
+ import click
16
+ import imageio
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ import torch.nn.functional as F
21
+
22
+ import dnnlib
23
+ import legacy
24
+
25
def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps = 1000,
    w_avg_samples = 10000,
    initial_learning_rate = 0.1,
    initial_noise_factor = 0.05,
    lr_rampdown_length = 0.25,
    lr_rampup_length = 0.05,
    noise_ramp_length = 0.75,
    regularize_noise_weight = 1e5,
    verbose = False,
    device: torch.device
):
    """Project a target image into the W latent space of a StyleGAN generator.

    Optimizes a single W vector (plus the generator's per-layer noise buffers)
    so that the synthesized image matches `target` under an LPIPS-style VGG16
    feature distance, with a regularizer discouraging spatial correlation in
    the noise maps.

    Args:
        G: Pretrained generator network (deep-copied internally; the caller's
            copy is not modified). Assumed to expose `z_dim`, `img_channels`,
            `img_resolution`, `mapping`, and `synthesis` — TODO confirm against
            the network pickle's generator class.
        target: Image tensor of shape [C, H, W] in dynamic range [0, 255];
            H and W must equal `G.img_resolution`.
        num_steps: Number of optimization steps.
        w_avg_samples: Number of random z samples used to estimate the W-space
            mean and standard deviation.
        initial_learning_rate: Peak Adam learning rate.
        initial_noise_factor: Initial scale (relative to w_std) of the noise
            added to w during optimization.
        lr_rampdown_length: Fraction of the run over which LR cosine-decays to 0.
        lr_rampup_length: Fraction of the run over which LR ramps up from 0.
        noise_ramp_length: Fraction of the run over which the w-noise decays to 0.
        regularize_noise_weight: Weight of the noise-correlation regularizer.
        verbose: If True, print per-step progress.
        device: Torch device to run the optimization on.

    Returns:
        Tensor of shape [num_steps, G.mapping.num_ws, C] holding the projected
        W vector after each step, broadcast across all layers.
    """
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        # Progress printing, gated on `verbose`.
        if verbose:
            print(*args)

    # Work on a frozen private copy so the caller's generator is untouched.
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    # Sample many z vectors (fixed seed for reproducibility) and map them to W
    # to estimate the W-space mean (optimization starting point) and stddev
    # (scale for the exploratory noise added to w each step).
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)     # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs. These per-layer buffers are optimized jointly with w.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector (TorchScript module used as an LPIPS-style
    # perceptual feature extractor; downloaded from NVIDIA's CDN).
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image. VGG expects small inputs, so downsample to
    # 256x256 when the generator resolution is larger.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # w_opt starts at the W mean; w_out records the projected w at every step.
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise. Randomize the buffers in place, then mark them trainable
    # (they were registered with the optimizer above).
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule: cosine ramp-down with a short linear ramp-up;
        # the w-noise scale decays quadratically to zero over noise_ramp_length.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w. The single w is broadcast to all num_ws layers.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        # Map synth output from [-1, 1] to [0, 255] to match the target's range.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images; perceptual distance is the squared
        # difference of LPIPS features.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization: penalize autocorrelation of each noise map at
        # shift 1 in both spatial dims, across a pyramid of 2x downsamplings.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise. Renormalize each buffer to zero mean / unit variance
        # so the regularizer targets correlation, not magnitude.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    # Broadcast each per-step w across all synthesis layers: [num_steps, num_ws, C].
    return w_out.repeat([1, G.mapping.num_ws, 1])
132
+
133
+ #----------------------------------------------------------------------------
134
+
135
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Loads the generator from `network_pkl`, center-crops and resizes the target
    image to the generator resolution, optimizes a W vector to reproduce it,
    and writes `target.png`, `proj.png`, `projected_w.npz` (and optionally
    `proj.mp4`) into `outdir`.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks. Requires a CUDA device.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image: center-crop to a square, then resize to the
    # generator's output resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        # BUGFIX: this branch previously only printed "Skipping video saving"
        # while the implementation was commented out, so --save-video (default
        # True, documented as "Save an mp4 video of optimization progress")
        # was a silent no-op. Restore the documented behavior: render each
        # per-step projection next to the target and encode as mp4.
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)  # [-1,1] -> [0,255]
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
207
+
208
+ #----------------------------------------------------------------------------
209
+
210
# CLI entry point; click fills in the arguments from sys.argv.
if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter
212
+
213
+ #----------------------------------------------------------------------------