| import sys |
| import os |
| import os.path as osp |
| import json |
| import argparse |
| import numpy as np |
| import clip |
| import torch |
| import math |
| from scipy.stats import truncnorm |
| from PIL import Image, ImageDraw |
| from torchvision.transforms import ToPILImage |
| from .config import SEMANTIC_DIPOLES_CORPORA |
|
|
|
|
def create_exp_dir(args):
    """Create the output directory for the current experiment under experiments/wip/ and save into it the given
    arguments (args.json) and the invoking command (command.sh).

    Experiment's directory name format:
        ContraCLIP_<gan>-<{Z,W,W+}>-K<num_semantic_dipoles>-D<num_latent_support_dipoles>-lss_beta_<lss_beta>
        -eps<min_shift_magnitude>_<max_shift_magnitude>-<styleclip/linear/nonlinear_css_beta_<css_beta>>
        -<loss>(_<temperature>)-<max_iter>-<corpus>

    where K is the number of entries in `SEMANTIC_DIPOLES_CORPORA[args.corpus]`, the latent space token is
    `args.stylegan_space` when 'stylegan' appears in `args.gan` and Z otherwise, and the `_<temperature>` suffix is
    appended only when `args.loss == "contrastive"`.

    E.g.:
        ContraCLIP_stylegan2_ffhq1024-W+-K3-D128-lss_beta_0.1-eps0.1_0.2-nonlinear_css_beta_0.75-contrastive_1.0-10000-expressions3

    Args:
        args (argparse.Namespace): the namespace object returned by `parse_args()` for the current run

    Returns:
        exp_dir (str): the experiment directory name (created as experiments/wip/<exp_dir>)

    """
    # Imported locally so this block stays self-contained; used to shell-quote the saved command line below
    import shlex

    exp_dir = "ContraCLIP_{}".format(args.gan)
    if 'stylegan' in args.gan:
        exp_dir += '-{}'.format(args.stylegan_space)
    else:
        exp_dir += '-Z'
    exp_dir += "-K{}-D{}".format(len(SEMANTIC_DIPOLES_CORPORA[args.corpus]), args.num_latent_support_dipoles)
    exp_dir += "-lss_beta_{}".format(args.lss_beta)
    exp_dir += "-eps{}_{}".format(args.min_shift_magnitude, args.max_shift_magnitude)
    if args.styleclip:
        exp_dir += "-styleclip"
    elif args.linear:
        exp_dir += "-linear"
    else:
        exp_dir += "-nonlinear_css_beta_{}".format(args.css_beta)
    exp_dir += "-{}".format(args.loss)
    if args.loss == "contrastive":
        exp_dir += "_{}".format(args.temperature)
    exp_dir += "-{}".format(args.max_iter)
    exp_dir += "-{}".format(args.corpus)

    wip_dir = osp.join("experiments", "wip", exp_dir)
    os.makedirs(wip_dir, exist_ok=True)

    # Save the parsed arguments
    with open(osp.join(wip_dir, 'args.json'), 'w') as args_json_file:
        json.dump(args.__dict__, args_json_file)

    # Save the invoking command as a bash script. Each argument is shell-quoted so that the script reproduces the run
    # even when an argument contains spaces or shell metacharacters (plain arguments are written unchanged).
    with open(osp.join(wip_dir, 'command.sh'), 'w') as command_file:
        command_file.write('#!/usr/bin/bash\n')
        command_file.write(' '.join(shlex.quote(arg) for arg in sys.argv) + '\n')

    return exp_dir
|
|
|
|
class PromptFeatures:
    """Wrap a corpus of text prompts together with the CLIP model used to encode them.

    Args:
        prompt_corpus (sequence): the prompts; each entry is passed as-is to `clip.tokenize`
        clip_model: CLIP model exposing `parameters()` and `encode_text()`
    """

    def __init__(self, prompt_corpus, clip_model):
        self.prompt_corpus = prompt_corpus
        self.clip_model = clip_model
        self.num_prompts = len(prompt_corpus)
        # Hard-coded; presumably the text-embedding size of the CLIP backbone in use -- TODO confirm
        self.prompt_features_dim = 512

    def get_prompt_features(self):
        """Encode every corpus entry with the CLIP text encoder.

        Each entry is tokenized and encoded separately on the device of the model's first parameter. A leading axis
        is added to each encoding, and the encodings are concatenated along it, so the first dimension of the
        result indexes the corpus entries. This runs without `torch.no_grad()`, so autograd tracks the computation
        unless the caller disables it.

        Returns:
            torch.Tensor: the stacked prompt features
        """
        model = self.clip_model
        target_device = next(model.parameters()).device

        encodings = []
        for prompt in self.prompt_corpus:
            tokens = clip.tokenize(prompt).to(target_device)
            encodings.append(model.encode_text(tokens).unsqueeze(0))
        return torch.cat(encodings, dim=0)
|
|
|
|
class TrainingStatTracker(object):
    """Accumulate per-iteration training statistics and report their means.

    Only the loss is tracked. `update` appends to it, `get_means` averages everything collected since the last
    `flush`, and `flush` starts a fresh collection window.
    """

    def __init__(self):
        # statistic name -> list of values recorded since the last flush
        self.stat_tracker = {'loss': []}

    def update(self, loss):
        """Record one loss value (anything accepted by `float()`, e.g. a Python number)."""
        self.stat_tracker['loss'].append(float(loss))

    def get_means(self):
        """Return a dict mapping each statistic name to the mean of its recorded values.

        Right after `flush` the value lists are empty; `np.mean` then returns NaN (and emits a RuntimeWarning).
        """
        return {name: np.mean(values) for name, values in self.stat_tracker.items()}

    def flush(self):
        """Discard all recorded values while keeping the set of tracked statistics."""
        for name in self.stat_tracker:
            self.stat_tracker[name] = []
|
|
|
|
def sample_z(batch_size, dim_z, truncation=None):
    """Draw a batch of latent codes from a (possibly truncated) multivariate standard Gaussian.

    Args:
        batch_size (int) : number of latent codes to draw
        dim_z (int)      : latent space dimensionality
        truncation (float) : if given and different from 1.0, every entry is drawn from a standard normal truncated
                             to [-truncation, truncation]; None or 1.0 means plain (untruncated) sampling

    Returns:
        z (torch.Tensor) : float tensor of shape (batch_size, dim_z)
    """
    use_truncation = truncation is not None and truncation != 1.0
    if use_truncation:
        # scipy draws float64 samples; cast to float32 to match torch.randn's default dtype
        samples = truncnorm.rvs(-truncation, truncation, size=(batch_size, dim_z))
        return torch.from_numpy(samples).to(torch.float)
    return torch.randn(batch_size, dim_z)
|
|
|
|
def tensor2image(tensor, adaptive=False):
    """Convert an image tensor into a PIL image.

    A leading dimension of size 1 (e.g. a batch of one) is squeezed away. The remaining tensor is mapped to [0, 255],
    cast to uint8 and handed to torchvision's `ToPILImage`.

    Args:
        tensor (torch.Tensor): image tensor
        adaptive (bool): if True, min-max normalize the tensor to [0, 1] using its own value range (a constant tensor
                         maps to all zeros). Otherwise the tensor is assumed to lie in [-1, 1]; it is mapped to [0, 1]
                         and out-of-range values are clamped.

    Returns:
        PIL.Image.Image: the converted image
    """
    tensor = tensor.squeeze(dim=0)
    if adaptive:
        tensor = tensor - tensor.min()
        value_range = tensor.max()
        # Guard against a constant tensor (zero range), which would otherwise produce NaNs (0 / 0)
        if value_range > 0:
            tensor = tensor / value_range
    else:
        # `clamp` is not in-place: its result must be re-assigned, otherwise values outside [-1, 1] reach the uint8
        # cast out of range
        tensor = ((tensor + 1) / 2).clamp(0, 1)
    return ToPILImage()((255 * tensor.cpu().detach()).to(torch.uint8))
|
|
|
|
def update_progress(msg, total, progress):
    """Draw a single-line, 20-cell text progress bar on stdout.

    The line starts with a carriage return so successive calls redraw it in place. Once `progress` reaches `total`,
    the bar is drawn full at 100% and the line is terminated with "\\r\\n".

    Args:
        msg (str)        : text printed before the bar
        total (number)   : value corresponding to 100%
        progress (number): current value
    """
    bar_length = 20
    fraction = float(progress) / float(total)
    line_end = ""
    if fraction >= 1.:
        fraction, line_end = 1, "\r\n"

    filled = int(round(bar_length * fraction))
    bar = u"\u2588" * filled + u"\u2591" * (bar_length - filled)
    percent = round(fraction * 100, 0)

    sys.stdout.write("\r{}{} {:.0f}% {}".format(msg, bar, percent, line_end))
    sys.stdout.flush()
|
|
|
|
def update_stdout(num_lines):
    """Update stdout by moving the cursor up over the last `num_lines` lines and erasing each of them.

    Uses ANSI escape sequences (cursor up, erase entire line), so it assumes a terminal that interprets them. The
    cursor ends at column 0 of the topmost erased line, so subsequent output overwrites that region. Nothing is
    written when `num_lines` <= 0.

    Args:
        num_lines (int): number of lines

    """
    cursor_up = '\x1b[1A'
    # "Erase in Line" with parameter 2 clears the whole line without moving the cursor. The previous value here
    # ('\x1b[1A') was a second cursor-up, so no line was ever erased.
    erase_line = '\x1b[2K'
    if num_lines <= 0:
        return
    # Step up one line and clear it, num_lines times. Then return to column 0, matching the cursor position the
    # previous print()-based implementation ended at.
    sys.stdout.write((cursor_up + erase_line) * num_lines + '\r')
    sys.stdout.flush()
|
|
|
|
def sec2dhms(t):
    """Format a duration given in seconds as days, hours, minutes and seconds.

    Each component is rendered as a zero-padded integer (fractional seconds are truncated by the `%02d` format).

    Args:
        t (float): duration in seconds

    Returns (string):
        "<days> days, <hours> hours, <minutes> minutes, and <seconds> seconds"

    """
    seconds_per_day = 24 * 3600
    days, remainder = divmod(t, seconds_per_day)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    return "%02d days, %02d hours, %02d minutes, and %02d seconds" % (days, hours, minutes, seconds)
|
|
|
|
def get_wh(img_paths):
    """Get the width and height shared by the images in the given list of paths.

    Args:
        img_paths (list): list of image paths

    Returns:
        width (int) : the common images width
        height (int) : the common images height

    Raises:
        ValueError: if `img_paths` is empty, or if the images do not all have the same resolution

    """
    if not img_paths:
        raise ValueError("No image paths given")

    sizes = set()
    for img_path in img_paths:
        # The context manager closes the file handle opened by Image.open (previously every handle was leaked)
        with Image.open(img_path) as img:
            sizes.add((img.width, img.height))

    if len(sizes) != 1:
        raise ValueError("Inconsistent image resolutions in {}".format(img_paths))
    # All images share one (width, height) pair. The previous code indexed heights[1], which raised IndexError for a
    # single image.
    width, height = sizes.pop()
    return width, height
|
|
|
|
def _open_resized(img_path, width, height):
    """Load the image at `img_path`, return a copy resized to (width, height), and close the underlying file."""
    with Image.open(img_path) as img:
        return img.resize((width, height))


def create_summarizing_gif(imgs_root, gif_filename, num_imgs=None, gif_size=None, gif_fps=30, gap=15, progress_bar_h=15,
                           progress_bar_color=(252, 186, 3)):
    """Create a summarizing GIF image given an images root directory (images generated across a certain latent path) and
    the number of images to appear as a static sequence.

    Every GIF frame shows a static sequence of (at most) `num_imgs` images, followed by a `gap` and then the animated
    part. The animated part plays all regular files found in `imgs_root`, sorted by file name. The static sequence
    takes every ceil(N / num_imgs)-th of those N files, so it may contain fewer than `num_imgs` images. Each frame has
    resolution ((num_static + 1) * gif_w + gap, gif_h).

    Args:
        imgs_root (str)            : directory of images (generated across a certain path)
        gif_filename (str)         : filename of the resulting GIF image
        num_imgs (int)             : maximum number of images used to build the static sequence shown before the
                                     animated part of the GIF (default / values larger than the number of available
                                     images: use all images)
        gif_size (int)             : side of each (square) image cell in the GIF; if None, the common resolution of
                                     the static images is used
        gif_fps (int)              : GIF frames per second
        gap (int)                  : a gap (in pixels) between the static sequence and the animated part of the GIF
        progress_bar_h (int)       : height of the progress bar depicted at the bottom of the animated part of the GIF
                                     image. If a non-positive number is given, the progress bar is disabled.
        progress_bar_color (tuple) : color of the progress bar

    Raises:
        NotADirectoryError: if `imgs_root` is not a directory
        ValueError: if `imgs_root` contains no files, or if `num_imgs` is not positive

    """
    if not osp.isdir(imgs_root):
        raise NotADirectoryError("Invalid directory: {}".format(imgs_root))

    # All regular files under `imgs_root`, sorted by file name
    path_images = sorted(osp.join(imgs_root, name) for name in os.listdir(imgs_root)
                         if osp.isfile(osp.join(imgs_root, name)))
    num_images = len(path_images)
    if num_images == 0:
        raise ValueError("No images found in {}".format(imgs_root))

    if num_imgs is None or num_imgs > num_images:
        num_imgs = num_images
    elif num_imgs <= 0:
        raise ValueError("num_imgs must be a positive integer, got {}".format(num_imgs))

    # Static sequence: every `step`-th image of the same sorted listing used for the animation. The previous code
    # rebuilt paths as '{:06}.jpg', which fails for any other naming scheme or extension.
    step = math.ceil(num_images / num_imgs)
    static_imgs = path_images[::step]
    num_imgs = len(static_imgs)

    if gif_size is not None:
        gif_w = gif_h = gif_size
    else:
        gif_w, gif_h = get_wh(static_imgs)

    # Build the static sequence once; it is pasted into every frame
    static_img_pil = Image.new('RGB', size=(num_imgs * gif_w, gif_h))
    for i, img_path in enumerate(static_imgs):
        static_img_pil.paste(_open_resized(img_path, gif_w, gif_h), (i * gif_w, 0))

    # x-offset of the animated part (right of the static sequence and the gap)
    anim_x = num_imgs * gif_w + gap
    gif_frames = []
    for i, img_path in enumerate(path_images):
        gif_frame_pil = Image.new('RGB', size=((num_imgs + 1) * gif_w + gap, gif_h), color=(255, 255, 255))
        gif_frame_pil.paste(static_img_pil, (0, 0))
        gif_frame_pil.paste(_open_resized(img_path, gif_w, gif_h), (anim_x, 0))

        # Progress bar under the animated part; its length grows with the frame index
        if progress_bar_h > 0:
            progress = (i / num_images) * gif_w
            ImageDraw.Draw(gif_frame_pil).rectangle(xy=[anim_x, gif_h - progress_bar_h, anim_x + progress, gif_h],
                                                    fill=progress_bar_color)

        gif_frames.append(gif_frame_pil)

    gif_frames[0].save(
        fp=gif_filename,
        append_images=gif_frames[1:],
        save_all=True,
        optimize=False,
        loop=0,
        duration=1000 // gif_fps)
|
|