# import sys; sys.path.extend(['.', 'src', '/home/skoroki/StyleCLIP'])
import argparse
import math
import os
from typing import List, Tuple
import json
import re
import random
import yaml
import itertools

import torchvision
from torch import optim
from PIL import Image
import click
import numpy as np
import torch
from tqdm import tqdm
from omegaconf import OmegaConf
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from torch import Tensor
import torchvision.transforms.functional as TVF
from torchvision.utils import save_image

from src.deps.facial_recognition.model_irse import Backbone

try:
    import clip
except ImportError:
    raise ImportError(
        "To edit videos with CLIP, you need to install the `clip` library. "
        "Please follow the instructions in https://github.com/openai/CLIP")

from src import dnnlib
import legacy
from src.scripts.project import save_edited_w
#----------------------------------------------------------------------------

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp
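# A note on the schedule above (editor's reading, not part of the original
# code): with rampup=0.05 and rampdown=0.25, the learning rate rises linearly
# over the first 5% of the steps, stays at `initial_lr` in the middle, and
# cosine-decays to zero over the last 25%. For example,
# get_lr(0.0, 0.1) == 0.0, get_lr(0.5, 0.1) == 0.1 and get_lr(1.0, 0.1) == 0.0.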
#----------------------------------------------------------------------------

class CLIPLoss(torch.nn.Module):
    """
    Copy-pasted and adapted from StyleCLIP
    """
    def __init__(self):
        super(CLIPLoss, self).__init__()

        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        # self.upsample = torch.nn.Upsample(scale_factor=7)
        # self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    def forward(self, image, text):
        # image = self.avg_pool(self.upsample(image))
        # print('shape', image.shape, text.shape)
        image = F.interpolate(image, size=(224, 224), mode='area')
        similarity = 1 - self.model(image, text)[0] / 100
        similarity = similarity.diag()

        return similarity
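# Usage sketch (illustrative; mirrors how run_edit_optimization calls it below):
#   clip_loss = CLIPLoss()
#   text = clip.tokenize(["a smiling face"]).cuda()   # [1, 77], prompt is made up
#   sims = clip_loss(img, text)                       # img: [1, 3, H, W] generator output
# CLIP's image-text logits are scaled by its learned temperature (roughly 100),
# so dividing by 100 and subtracting from 1 yields a per-pair dissimilarity;
# .diag() keeps only the matched (image_i, text_i) pairs.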
#----------------------------------------------------------------------------

class IDLoss(nn.Module):
    """
    Copy-pasted from StyleCLIP
    """
    def __init__(self):
        super(IDLoss, self).__init__()

        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        with dnnlib.util.open_url(Backbone.WEIGHTS_URL, verbose=True) as f:
            ir_se50_weights = torch.load(f)
        self.facenet.load_state_dict(ir_se50_weights)
        self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        self.facenet.cuda()

    def extract_feats(self, x):
        if x.shape[2] != 256:
            x = self.pool(x)
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)

        return x_feats

    def forward(self, y_hat, y):
        n_samples = y.shape[0]
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0

        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            loss += 1 - diff_target

        return loss / n_samples
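# Usage sketch (illustrative): IDLoss compares identity embeddings of the edited
# and original images and averages 1 - <e_hat, e> over the batch; since the
# ir_se50 backbone outputs normalized embeddings, this is roughly one minus the
# cosine similarity, so minimizing it keeps the edited face close to the
# original identity.
#   id_loss = IDLoss()
#   loss = id_loss(img_edited, img_orig)   # both [N, 3, H, W]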
#----------------------------------------------------------------------------

def run_edit_optimization(
    _sentinel=None,
    G: nn.Module=None,
    w_orig: Tensor=None,
    descriptions: List[str]=None,
    # ckpt: str="stylegan2-ffhq-config-f.pt",
    lr: float=0.1,
    num_steps: int=40,
    l2_lambda: float=0.001,
    id_lambda: float=0.005,
    # latent_path: str=latent_path,
    # truncation: float=0.7,
    # save_intermediate_image_every: int=1 if create_video else 20,
    # results_dir: str="results",
    mask: Tensor=None,
    mask_lambda: float=0.0,
    verbose: bool=False,
) -> Tuple[Tensor, Tensor]:
    assert _sentinel is None, "Pass all arguments as keyword arguments."

    # text_inputs = torch.cat([clip.tokenize(d) for d in descriptions]).to(device)
    num_prompts = len(descriptions)
    num_images = len(w_orig)
    device = w_orig.device
    text_inputs = clip.tokenize(descriptions).to(device) # [num_prompts, 77]
    text_inputs = text_inputs.repeat_interleave(len(w_orig), dim=0) # [num_prompts * num_images, 77]
    c = torch.zeros(num_prompts * num_images, 0, device=device)
    ts = torch.zeros(num_prompts * num_images, 1, device=device)
    w_orig = w_orig.repeat(num_prompts, 1, 1) # [num_prompts * num_images, num_ws, w_dim]

    with torch.no_grad():
        img_orig = G.synthesis(ws=w_orig, c=c, t=ts) # [num_prompts * num_images, c, h, w]

    w = w_orig.detach().clone() # [num_prompts * num_images, num_ws, w_dim]
    w.requires_grad = True

    if mask_lambda > 0:
        # NOTE: this branch relies on a `vgg16` LPIPS feature extractor which is
        # never defined in this file, and the corresponding branch in the
        # optimization loop raises NotImplementedError, so it is effectively disabled.
        target_image = img_orig * (1 - mask) # [num_prompts * num_images, c, h, w]
        # target_image = img_orig[:, :, -128:, :128]
        target_image = (target_image * 0.5 + 0.5) * 255.0 # [num_prompts * num_images, c, h, w]
        if target_image.shape[2] > 256:
            target_image = F.interpolate(target_image, size=(256, 256), mode='area')
        target_features = vgg16(target_image, resize_images=False, return_lpips=True)
        # dist = (target_features - synth_features).square().sum()
    else:
        target_features = None

    clip_loss = CLIPLoss()
    id_loss = IDLoss()
    optimizer = optim.Adam([w], lr=lr)

    if verbose:
        pbar = tqdm(range(num_steps))
    else:
        pbar = range(num_steps)

    for curr_iter in pbar:
        curr_lr = get_lr(curr_iter / num_steps, lr)
        # optimizer.param_groups[0]["lr"] = lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_lr

        # img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=work_in_stylespace)
        img_gen = G.synthesis(ws=w, c=c, t=ts) # [num_prompts * num_images, c, h, w]

        if mask_lambda > 0:
            raise NotImplementedError
            synth_image = img_gen * (1 - mask)
            # synth_image = img_gen[:, :, -128:, :128]
            synth_image = (synth_image * 0.5 + 0.5) * 255.0
            if synth_image.shape[2] > 256:
                synth_image = F.interpolate(synth_image, size=(256, 256), mode='area')
            synth_features = vgg16(synth_image, resize_images=False, return_lpips=True)
            mask_loss = (target_features - synth_features).square().sum()
        else:
            mask_loss = 0

        if mask is not None:
            img_gen = img_gen * mask.unsqueeze(0) # [num_prompts * num_images, c, h, w]

        c_loss = clip_loss(img_gen, text_inputs) # [num_prompts * num_images]

        if id_lambda > 0:
            i_loss = id_loss(img_gen, img_orig)
        else:
            i_loss = 0

        l2_loss = ((w_orig - w) ** 2) # [num_prompts * num_images, num_ws, w_dim]
        loss = c_loss.sum() + l2_lambda * l2_loss.sum() + id_lambda * i_loss + mask_lambda * mask_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if verbose:
            pbar.set_description(f"loss: {loss.item():.4f};")

    final_result = torch.stack([img_orig, img_gen.detach()]) # [2, num_prompts * num_images, c, h, w]

    return final_result, w.detach()
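# Usage sketch (illustrative; `G` and `w_orig` are assumed to come from a loaded
# generator and a sampled/projected batch of W latents, and the prompt is made up):
#   final_image, w_edited = run_edit_optimization(
#       G=G,
#       w_orig=w_orig,                          # [num_images, num_ws, w_dim]
#       descriptions=["a person with a beard"],
#       lr=0.1, num_steps=40, l2_lambda=0.001, id_lambda=0.005,
#   )
#   # final_image[0] holds the original renders, final_image[1] the edited ones.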
# x, new_w = main(args)
# pair = torch.cat([img for img in x], dim=2)
# TVF.to_pil_image((pair.cpu().detach() * 0.5 + 0.5).clamp(0, 1))

#----------------------------------------------------------------------------
# @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
# @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
# @click.option('--same_motion_codes', type=bool, help='Should we use the same motion codes for all videos?', default=False, show_default=True)
# l2_lambda=0.001,
# id_lambda=0.005,
# l2_lambda=0.0005,
# id_lambda=0.0,
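# NOTE: the @click decorators for `main` are missing from this file. The options
# below are reconstructed from the function signature so that the script can be
# invoked from the command line; option names match the parameters, but the
# defaults and help strings are editor assumptions, not the original values.
@click.command()
@click.option('--network_pkl', type=str, default=None, help='Path or URL of the network pickle')
@click.option('--networks_dir', type=str, default=None, help='Directory with network-snapshot-*.pkl checkpoints')
@click.option('--seed', type=int, default=0, help='Random seed')
@click.option('--w_dir', type=str, default=None, help='Directory with projected *_w.pt latents')
@click.option('--results_dir', type=str, default='results', help='Output directory used when sampling new w')
@click.option('--truncation_psi', type=float, default=0.7, help='Truncation psi')
@click.option('--num_w', type=int, default=1, help='Number of latents to sample when --w_dir is not given')
@click.option('--zero_periods', type=bool, default=False, help='Zero out the time-encoder period predictor')
@click.option('--num_weights_to_slice', type=int, default=0, help='Number of time-encoder weights to zero out')
@click.option('--num_steps', type=int, default=40, help='Optimization steps (overridden by the HPO grid)')
@click.option('--stack_samples', type=bool, default=False, help='Stack samples into a single image')
@click.option('--l2_lambda', type=float, default=0.001, help='L2 penalty weight (overridden by the HPO grid)')
@click.option('--id_lambda', type=float, default=0.005, help='ID loss weight (overridden by the HPO grid)')
@click.option('--lr', type=float, default=0.1, help='Learning rate')
@click.option('--prompts', type=str, required=True, help='Path to a YAML file with {edit_name: prompt} pairs')
@click.option('--mask_lambda', type=float, default=0.0, help='Mask loss weight')
@click.option('--use_id_lambda', type=bool, default=False, help='Sweep over id_lambda values during HPO')
@click.pass_context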
def main(
    ctx: click.Context,
    network_pkl: str,
    networks_dir: str,
    seed: int,
    w_dir: str,
    results_dir: str,
    truncation_psi: float,
    num_w: int,
    # save_as_mp4: bool,
    # video_len: int,
    # fps: int,
    # as_grids: bool,
    zero_periods: bool,
    num_weights_to_slice: int,
    num_steps: int,
    stack_samples: bool,
    l2_lambda: float,
    id_lambda: float,
    lr: float,
    prompts: str,
    mask_lambda: float,
    use_id_lambda: bool,
):
    if network_pkl is None:
        ckpt_regex = re.compile(r"^network-snapshot-\d{6}\.pkl$")
        # ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
        # network_pkl = os.path.join(networks_dir, ckpts[-1])

        # Select the checkpoint with the best FVD score.
        metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
        with open(metrics_file, 'r') as f:
            snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
        best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
        network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
        print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
    else:
        assert networks_dir is None, "Can't have both parameters: network_pkl and networks_dir"

    print('Loading networks from "%s"...' % network_pkl, end='')
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
    print('Loaded!')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if zero_periods:
        G.synthesis.motion_encoder.time_encoder.periods_predictor.weight.data.zero_()

    if num_weights_to_slice > 0:
        G.synthesis.motion_encoder.time_encoder.weights[:, -num_weights_to_slice:] = 0.0

    # description = "Bright sunny sky and mountains far away"
    # experiment_type = 'edit' #@param ['edit', 'free_generation']
    # mask = torch.zeros(3, 256, 256, device=device)
    # mask[:, :, 64+32 : 128+32] = 1.0
    # mask[:, :-128, :] = 1.0
    # mask[:, :, 128:] = 1.0
    if w_dir is None:
        print('Sampling new w')
        z = torch.randn(num_w, G.z_dim, device=device)
        c = torch.zeros(len(z), G.c_dim, device=device)
        w_orig = G.mapping(z=z, c=c, truncation_psi=truncation_psi)
        os.makedirs(results_dir, exist_ok=True)
        torch.save(w_orig.cpu(), f'{results_dir}_w_orig.pt')
        w_save_dir = os.path.join(results_dir, 'w_edit')
        samples_save_dir = os.path.join(results_dir, 'edited_samples')
    else:
        w_paths = sorted([os.path.join(w_dir, f) for f in os.listdir(w_dir) if f.endswith('_w.pt')])
        w_names = [os.path.basename(f) for f in w_paths]
        w_orig = [torch.load(f) for f in w_paths]
        w_orig = torch.stack(w_orig).to(device) # [num_images, num_ws, w_dim]
        w_save_dir = f'{w_dir}_edited_w'
        samples_save_dir = f'{w_dir}_edited_samples'

    os.makedirs(w_save_dir, exist_ok=True)
    os.makedirs(samples_save_dir, exist_ok=True)
    print(f'Loading prompts from file: {prompts}')
    with open(prompts, 'r') as f:
        descs_dict = yaml.safe_load(f)
    edit_names, descriptions = list(zip(*descs_dict.items()))

    # The HPO grids below override the corresponding CLI arguments.
    del id_lambda, num_steps, l2_lambda
    l2_lambdas = [1000000.0, 0.0025, 0.001, 0.00025, 0.0005, 0.0001]
    if use_id_lambda:
        id_lambdas = [0.005, 0.0025, 0.001, 0.00025, 0.0005, 0.0001, 0.0]
    else:
        id_lambdas = [0.0]
    all_num_steps = [40]
    for curr_edit_name, curr_prompt in zip(edit_names, descriptions):
        all_images = []
        all_w_edited = []

        for l2_lambda, id_lambda, num_steps in tqdm(list(itertools.product(l2_lambdas, id_lambdas, all_num_steps)), desc=f'Performing HPO for {curr_edit_name}'):
            final_image, w_edited = run_edit_optimization(
                G=G,
                w_orig=w_orig,
                descriptions=[curr_prompt],
                # ckpt="stylegan2-ffhq-config-f.pt",
                lr=lr,
                num_steps=num_steps,
                l2_lambda=l2_lambda,
                id_lambda=id_lambda,
                mask_lambda=mask_lambda,
                verbose=False,
                # latent_path=latent_path,
                # truncation=0.7,
                # mask=None,
                # mask_lambda=0.1,
            )
            all_images.extend((final_image[1].cpu() * 0.5 + 0.5).clamp(0, 1))
            all_w_edited.append({
                "w_edit": w_edited.cpu(),
                "l2_lambda": l2_lambda,
                "id_lambda": id_lambda,
                "num_steps": num_steps,
                "prompt": curr_prompt,
                "edit_name": curr_edit_name,
            })

        # img_names = [f'{w_name}_{edit_name}' for edit_name in edit_names for w_name in w_names]
        # save_edited_w(
        #     G=G,
        #     w_outdir=f'{w_dir}_edited',
        #     samples_outdir=f'{w_dir}_projected_samples',
        #     img_names=img_names,
        #     stack_samples=stack_samples,
        #     all_w=w_edited,
        #     all_motion_z=None,
        #     stacked_samples_out_path=f'{w_dir}_edited_samples.png',
        # )

        torch.save(all_w_edited, f"{w_save_dir}/{curr_edit_name}_w.pt")
        grid = utils.make_grid(torch.stack(all_images), nrow=len(w_orig))
        print('Saving into', f"{samples_save_dir}/{curr_edit_name}.png")
        save_image(grid, f"{samples_save_dir}/{curr_edit_name}.png")

    print('Done!')
#----------------------------------------------------------------------------

if __name__ == "__main__":
    main() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------