Spaces:
Paused
Paused
| import copy | |
| import os | |
| import random | |
| import urllib.request | |
| import torch | |
| import torch.nn.functional as FF | |
| import torch.optim | |
| from torchvision import utils | |
| from tqdm import tqdm | |
| from stylegan2.model import Generator | |
class DownloadProgressBar(tqdm):
    """``tqdm`` progress bar whose ``update_to`` can be used as a ``urlretrieve`` reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize``.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block, in bytes.
            tsize: Total size, in bytes; the bar's total is left unchanged when ``None``.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # tqdm.update() expects an increment, so subtract what the bar already shows.
        self.update(transferred - self.n)
def get_path(base_path, base_dir='checkpoints'):
    """Return the local path of checkpoint ``base_path``, downloading it if absent.

    The file is looked up at ``os.path.join(base_dir, base_path)``. If it does
    not exist, it is fetched from the ``aaronb/StyleGAN2`` Hugging Face
    repository.

    The download is written to a temporary ``<save_path>.part`` file and
    renamed into place only after it finishes. An interrupted or failed
    download therefore never leaves a truncated file at ``save_path`` that
    later calls would mistake for a complete checkpoint.

    Args:
        base_path: File name (or relative path) of the checkpoint.
        base_dir: Directory holding downloaded checkpoints (relative to the
            current working directory unless absolute).

    Returns:
        ``os.path.join(base_dir, base_path)``.
    """
    save_path = os.path.join(base_dir, base_path)
    if not os.path.exists(save_path):
        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
        print(f'{base_path} not found')
        print('Try to download from huggingface: ', url)
        directory = os.path.dirname(save_path)
        if directory:  # os.makedirs('') would raise
            os.makedirs(directory, exist_ok=True)
        tmp_path = save_path + '.part'
        try:
            download_url(url, tmp_path)
        except BaseException:
            # Also catches KeyboardInterrupt: drop the partial file, then re-raise.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        # Atomic on the same filesystem; save_path only ever holds a full download.
        os.replace(tmp_path, save_path)
        print('Downloaded to ', save_path)
    return save_path
def download_url(url, output_path):
    """Fetch ``url`` into ``output_path`` while showing a byte-level tqdm progress bar.

    The bar is labelled with the last ``/``-separated segment of ``url``.
    """
    label = url.rsplit('/', 1)[-1]
    bar = DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=label)
    with bar:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=bar.update_to)
class CustomGenerator(Generator):
    """``Generator`` split into latent preparation (``prepare``) and synthesis (``generate``).

    ``generate`` also returns an intermediate feature map, resized to the
    output image size, for use by the drag-editing code.
    """

    def prepare(
        self,
        styles,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Build the per-layer latent tensor and noise list consumed by ``generate``.

        Args:
            styles: List of style codes. They are passed through ``self.style``
                unless ``input_is_latent`` is true.
            inject_index: Number of leading layers that take ``styles[0]`` when
                two styles are mixed; drawn uniformly from
                ``[1, n_latent - 1]`` when ``None``. Ignored for a single style.
            truncation: Values below 1 move each style toward
                ``truncation_latent`` (which must then be provided).
            truncation_latent: Center used for truncation.
            input_is_latent: Skip ``self.style`` when true.
            noise: Per-layer noise list. When ``None``, it is ``[None] * num_layers``
                if ``randomize_noise``, else the ``self.noises.noise_{i}`` tensors.
            randomize_noise: See ``noise``.

        Returns:
            ``(latent, noise)``. For 2-D styles, ``latent`` is built by
            repeating/concatenating along a new axis 1 of length ``n_latent``;
            a single style that is already 3-D is returned unchanged.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]
        if truncation < 1:
            style_t = []
            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )
            styles = style_t
        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)
        return latent, noise

    def generate(
        self,
        latent,
        noise,
        feature_res=256,
    ):
        """Synthesize an image from a prepared latent and return it with a feature map.

        Args:
            latent: Per-layer latent from ``prepare``; layer ``k`` reads ``latent[:, k]``.
            noise: Per-layer noise list from ``prepare``.
            feature_res: Width of the synthesis-block output to return as features.

        Returns:
            ``(image, features)``. ``image`` is the output accumulated through the
            ``to_rgb`` skip connections. ``features`` is the output of the block
            whose width equals ``feature_res``, bilinearly resized to the image's
            spatial size. If no block reaches ``feature_res`` (e.g. a generator
            smaller than 256 px), the last block's output is used instead.
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        feat = None
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            if out.shape[-1] == feature_res:
                feat = out
            i += 2
        image = skip
        if feat is None:
            # No block produced a feature_res-wide map; use the final features
            # rather than failing on an unbound name.
            feat = out
        feat = FF.interpolate(feat, image.shape[-2:], mode='bilinear')
        return image, feat
def stylegan2(
    size=512,
    channel_multiplier=2,
    latent=512,
    n_mlp=8,
    ckpt='stylegan2-ffhq-config-f.pt'
):
    """Build a frozen, eval-mode ``CustomGenerator`` loaded with the ``g_ema`` weights of ``ckpt``.

    Args:
        size: Output size passed to the generator constructor.
        channel_multiplier: Passed to the generator constructor.
        latent: Style dimensionality passed to the generator constructor.
        n_mlp: Passed to the generator constructor.
        ckpt: Checkpoint name resolved through ``get_path`` (downloaded if missing).

    Returns:
        The generator with ``requires_grad`` disabled, on the device it was
        constructed on (no ``.to()``/``.cuda()`` is applied here).
    """
    g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
    # Deserialize onto the CPU: load_state_dict copies into g_ema's existing
    # parameters anyway, so this only avoids needing CUDA (and GPU memory for
    # the whole checkpoint) when the file was saved from a GPU.
    # NOTE(security): torch.load unpickles the file, which is fetched from the
    # network by get_path; only use trusted checkpoints.
    checkpoint = torch.load(get_path(ckpt), map_location='cpu')
    # strict=False: missing/unexpected keys are ignored rather than raising.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.requires_grad_(False)
    g_ema.eval()
    return g_ema
def bilinear_interpolate_torch(im, y, x):
    """Bilinearly sample ``im`` at fractional pixel locations.

    Args:
        im: Tensor of shape ``(B, C, H, W)``.
        y: Row coordinates as a float tensor (0-dim, or any shape of points).
        x: Column coordinates as a float tensor, same shape as ``y``.

    Returns:
        Tensor of shape ``(B, C, *y.shape)`` interpolated from the four
        surrounding pixels.

    Locations outside ``[0, H-1] x [0, W-1]`` are clamped to the border
    (edge replication) instead of wrapping to the opposite edge or raising
    ``IndexError``. Interior results are unaffected.
    """
    height, width = im.shape[2], im.shape[3]

    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Weights use the unclamped corners so they always sum to 1.
    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    # Clamp the corner indices into the image (replaces the old wrap hack,
    # which only handled x1 == W / y1 == H).
    x0 = x0.clamp(0, width - 1)
    x1 = x1.clamp(0, width - 1)
    y0 = y0.clamp(0, height - 1)
    y1 = y1.clamp(0, height - 1)

    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]
    return Ia * wa + Ib * wb + Ic * wc + Id * wd
def _as_point(p, device):
    """Return ``p`` (assumed ``[row, col]``) as a float32 tensor on ``device``; for read-only use."""
    return torch.as_tensor(p, dtype=torch.float32, device=device)


def _square_neighborhood(center, radius, height, width):
    """Integer ``(row, col)`` points whose offsets from ``center`` are both in ``[-radius, radius]``.

    ``center`` is truncated to ints, and the square is clipped to a
    ``height x width`` map. Returns a float tensor of shape ``(N, 2)`` on
    ``center.device`` in row-major order; ``N`` is 0 when the square lies
    entirely outside the map.
    """
    row, col = int(center[0]), int(center[1])
    row_lo, row_hi = max(row - radius, 0), min(row + radius, height - 1) + 1
    col_lo, col_hi = max(col - radius, 0), min(col + radius, width - 1) + 1
    # max(hi, lo) keeps arange valid (empty) when the square misses the map.
    rows = torch.arange(row_lo, max(row_hi, row_lo), device=center.device)
    cols = torch.arange(col_lo, max(col_hi, col_lo), device=center.device)
    grid = torch.stack([rows.repeat_interleave(len(cols)), cols.repeat(len(rows))], dim=1)
    return grid.float()


def _motion_supervision_loss(feat, handles, targets, radius, min_dist):
    """Sum of L1 terms pulling the features around each handle one unit step toward its target.

    For each handle farther than ``min_dist`` from its target, the detached
    features at the points around the handle are compared with the features
    sampled one pixel further along the unit direction to the target. Each
    point contributes its mean over (batch, channel). Handles within
    ``min_dist`` contribute nothing, which also avoids a 0/0 direction.
    """
    height, width = feat.shape[-2], feat.shape[-1]
    loss = feat.new_zeros(())
    for handle, target in zip(handles, targets):
        offset = target - handle
        dist = float(torch.norm(offset))
        if dist <= min_dist:
            continue
        step = offset / dist  # unit vector toward the target
        points = _square_neighborhood(handle, radius, height, width)
        shifted_r = (points[:, 0] + step[0]).clamp(0, height - 1)
        shifted_c = (points[:, 1] + step[1]).clamp(0, width - 1)
        ref = bilinear_interpolate_torch(feat, points[:, 0], points[:, 1]).detach()
        moved = bilinear_interpolate_torch(feat, shifted_r, shifted_c)
        loss = loss + (moved - ref).abs().mean(dim=(0, 1)).sum()
    return loss


def _track_point(feat, template, center, radius):
    """Return the ``(row, col)`` near ``center`` whose feature is L1-closest to ``template``.

    Ties go to the first candidate in row-major order. Returns ``None`` when
    the search window is empty.
    """
    candidates = _square_neighborhood(center, radius, feat.shape[-2], feat.shape[-1])
    if candidates.shape[0] == 0:
        return None
    sampled = bilinear_interpolate_torch(feat, candidates[:, 0], candidates[:, 1])
    cost = (sampled - template.unsqueeze(-1)).abs().sum(dim=(0, 1))
    best = candidates[torch.argmin(cost)]
    return int(best[0]), int(best[1])


def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask,
             max_iters=1000, *, r1=3, r2=12, lam=20, d=1, lr=2e-3):
    """Move ``handle_points`` toward ``target_points`` by optimizing the latent; a generator.

    Each iteration does the following:

    1. Take one Adam step on ``latent[:, :6]`` (the rest stays frozen). The
       loss is a motion term around each handle plus ``lam`` times the mean
       absolute feature change where ``mask`` is 0.
    2. Regenerate from the *updated* latent.
    3. Move each handle to the point within ``r2`` whose features best match
       that handle's features in ``F``.

    Args:
        g_ema: Generator exposing ``generate(latent, noise) -> (image, features)``.
        latent: Per-layer latent; entries ``[:, :6]`` are optimized.
        noise: Noise list forwarded to ``g_ema.generate``.
        F: Features of the unedited image. They supply the tracking templates
            and the reference for the mask term.
        handle_points: List of ``[row, col]`` points, updated **in place** with
            the tracked positions. Assumes each point is a mutable 2-element
            tensor; TODO confirm against the caller.
        target_points: ``[row, col]`` goals, one per handle.
        mask: Used as ``(1 - mask)``; presumably broadcastable to ``F`` with 1
            marking the editable region. Confirm with the caller.
        max_iters: Maximum number of optimization steps.
        r1: Half-width of the square motion-supervision neighborhood.
        r2: Half-width of the square tracking search window.
        lam: Weight of the masked feature-preservation term.
        d: Handles within this distance of their target are no longer pushed.
            Iteration stops once every handle is within ``d``.
        lr: Adam learning rate.

    Yields:
        ``(image, latent, features, handle_points)`` after each step, where
        ``handle_points`` is the same (mutated) object that was passed in.
    """
    handle_points0 = copy.deepcopy(handle_points)
    F0 = F.detach().clone()
    device = F0.device

    # Tracking templates: original features at the initial handle positions.
    templates = []
    for p in handle_points0:
        p0 = _as_point(p, device)
        templates.append(bilinear_interpolate_torch(F0, p0[0], p0[1]))

    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
    latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
    optimizer = torch.optim.Adam([latent_trainable], lr=lr)

    for step in range(max_iters):
        # Motion supervision.
        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
        _, F2 = g_ema.generate(latent, noise)
        handles = [_as_point(p, device) for p in handle_points]
        targets = [_as_point(t, device) for t in target_points]
        loss = _motion_supervision_loss(F2, handles, targets, r1, d)
        loss = loss + ((F2 - F0) * (1 - mask)).abs().mean() * lam
        loss.backward()
        optimizer.step()

        # Point tracking on features of the latent *after* the optimizer step.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            for i, template in enumerate(templates):
                found = _track_point(F2, template, _as_point(handle_points[i], device), r2)
                if found is not None:
                    handle_points[i][0] = found[0]
                    handle_points[i][1] = found[1]

        print(step, loss.item(), handle_points, target_points)
        yield sample2, latent, F2, handle_points

        # Stop once every handle is within `d` pixels of its target.
        if all(
            float(torch.norm(_as_point(t, device) - _as_point(p, device))) <= d
            for p, t in zip(handle_points, target_points)
        ):
            return