Delete Junk
Browse files
inference/.ipynb_checkpoints/icgan_colab-checkpoint.ipynb
DELETED
|
@@ -1,707 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "markdown",
|
| 5 |
-
"id": "bb75a2b9",
|
| 6 |
-
"metadata": {},
|
| 7 |
-
"source": [
|
| 8 |
-
"Copyright (c) Facebook, Inc. and its affiliates.\n",
|
| 9 |
-
"All rights reserved.\n",
|
| 10 |
-
"\n",
|
| 11 |
-
"This source code is licensed under the license found in the\n",
|
| 12 |
-
"LICENSE file in the root directory of this source tree."
|
| 13 |
-
]
|
| 14 |
-
},
|
| 15 |
-
{
|
| 16 |
-
"cell_type": "markdown",
|
| 17 |
-
"id": "81a8ddb6",
|
| 18 |
-
"metadata": {},
|
| 19 |
-
"source": [
|
| 20 |
-
"# IC-GAN\n",
|
| 21 |
-
"\n",
|
| 22 |
-
"\n",
|
| 23 |
-
"Official Colab notebook from the paper <b>\"Instance-Conditioned GAN\"</b> by Arantxa Casanova, Marlene Careil, Jakob Verbeek, Michał Drożdżal, Adriana Romero-Soriano.\n",
|
| 24 |
-
"\n",
|
| 25 |
-
"This Colab provides the code to generate images with IC-GAN, with the option of further guiding the generation with captions (CLIP). \n",
|
| 26 |
-
"\n",
|
| 27 |
-
"Based on the Colab [WanderClip](https://j.mp/wanderclip) by Eyal Gruss [@eyaler](https://twitter.com/eyaler) [eyalgruss.com](https://eyalgruss.com)\n",
|
| 28 |
-
"\n",
|
| 29 |
-
"Using the work from [our repository](https://github.com/facebookresearch/ic_gan)\n",
|
| 30 |
-
"\n",
|
| 31 |
-
"https://github.com/openai/CLIP, Copyright (c) 2021 OpenAI\n",
|
| 32 |
-
"\n",
|
| 33 |
-
"https://github.com/huggingface/pytorch-pretrained-BigGAN, Copyright (c) 2019 Thomas Wolf\n",
|
| 34 |
-
"\n",
|
| 35 |
-
"\n"
|
| 36 |
-
]
|
| 37 |
-
},
|
| 38 |
-
{
|
| 39 |
-
"cell_type": "code",
|
| 40 |
-
"execution_count": null,
|
| 41 |
-
"id": "9442671e",
|
| 42 |
-
"metadata": {},
|
| 43 |
-
"outputs": [],
|
| 44 |
-
"source": [
|
| 45 |
-
"#@title Restart after running this cell!\n",
|
| 46 |
-
"\n",
|
| 47 |
-
"!nvidia-smi -L\n",
|
| 48 |
-
"\n",
|
| 49 |
-
"import subprocess\n",
|
| 50 |
-
"\n",
|
| 51 |
-
"CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
|
| 52 |
-
"print(\"CUDA version:\", CUDA_version)\n",
|
| 53 |
-
"\n",
|
| 54 |
-
"if CUDA_version == \"10.1\":\n",
|
| 55 |
-
" torch_version_suffix = \"+cu101\"\n",
|
| 56 |
-
"elif CUDA_version == \"10.2\":\n",
|
| 57 |
-
" torch_version_suffix = \"+cu102\"\n",
|
| 58 |
-
"else:\n",
|
| 59 |
-
" torch_version_suffix = \"+cu111\"\n",
|
| 60 |
-
"\n",
|
| 61 |
-
"!pip install torch==1.8.0{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
|
| 62 |
-
]
|
| 63 |
-
},
|
| 64 |
-
{
|
| 65 |
-
"cell_type": "code",
|
| 66 |
-
"execution_count": null,
|
| 67 |
-
"id": "b01f51f1",
|
| 68 |
-
"metadata": {},
|
| 69 |
-
"outputs": [],
|
| 70 |
-
"source": [
|
| 71 |
-
"#@title Setup\n",
|
| 72 |
-
"!git clone https://github.com/facebookresearch/ic_gan.git\n",
|
| 73 |
-
"\n",
|
| 74 |
-
"%cd /content\n",
|
| 75 |
-
"# Uncompress required files\n",
|
| 76 |
-
"!wget https://dl.fbaipublicfiles.com/ic_gan/cc_icgan_biggan_imagenet_res256.tar.gz\n",
|
| 77 |
-
"!tar -xvf cc_icgan_biggan_imagenet_res256.tar.gz\n",
|
| 78 |
-
"!wget https://dl.fbaipublicfiles.com/ic_gan/icgan_biggan_imagenet_res256.tar.gz\n",
|
| 79 |
-
"!tar -xvf icgan_biggan_imagenet_res256.tar.gz\n",
|
| 80 |
-
"!wget https://dl.fbaipublicfiles.com/ic_gan/stored_instances.tar.gz\n",
|
| 81 |
-
"!tar -xvf stored_instances.tar.gz\n",
|
| 82 |
-
"\n",
|
| 83 |
-
"!pip install pytorch-pretrained-biggan\n",
|
| 84 |
-
"\n",
|
| 85 |
-
"!git clone --depth 1 https://github.com/openai/CLIP\n",
|
| 86 |
-
"!pip install ftfy\n",
|
| 87 |
-
"%cd /content/CLIP\n",
|
| 88 |
-
"import clip\n",
|
| 89 |
-
"last_clip_model = 'ViT-B/32'\n",
|
| 90 |
-
"perceptor, preprocess = clip.load(last_clip_model)\n",
|
| 91 |
-
"\n",
|
| 92 |
-
"import nltk\n",
|
| 93 |
-
"nltk.download('wordnet')\n",
|
| 94 |
-
"\n",
|
| 95 |
-
"!pip install cma\n"
|
| 96 |
-
]
|
| 97 |
-
},
|
| 98 |
-
{
|
| 99 |
-
"cell_type": "code",
|
| 100 |
-
"execution_count": null,
|
| 101 |
-
"id": "fa6c7629",
|
| 102 |
-
"metadata": {},
|
| 103 |
-
"outputs": [],
|
| 104 |
-
"source": [
|
| 105 |
-
"#@title Prepare functions\n",
|
| 106 |
-
"from pytorch_pretrained_biggan import BigGAN, convert_to_images, one_hot_from_names, utils\n",
|
| 107 |
-
"\n",
|
| 108 |
-
"%cd /content/ic_gan/\n",
|
| 109 |
-
"import sys\n",
|
| 110 |
-
"import os\n",
|
| 111 |
-
"sys.path[0] = '/content/ic_gan/inference'\n",
|
| 112 |
-
"sys.path.insert(1, os.path.join(sys.path[0], \"..\"))\n",
|
| 113 |
-
"import torch \n",
|
| 114 |
-
"\n",
|
| 115 |
-
"import numpy as np\n",
|
| 116 |
-
"import torch\n",
|
| 117 |
-
"import torchvision\n",
|
| 118 |
-
"import sys\n",
|
| 119 |
-
"torch.manual_seed(np.random.randint(sys.maxsize))\n",
|
| 120 |
-
"import imageio\n",
|
| 121 |
-
"from IPython.display import HTML, Image, clear_output\n",
|
| 122 |
-
"from PIL import Image as Image_PIL\n",
|
| 123 |
-
"from scipy.stats import truncnorm, dirichlet\n",
|
| 124 |
-
"from torch import nn\n",
|
| 125 |
-
"from nltk.corpus import wordnet as wn\n",
|
| 126 |
-
"from base64 import b64encode\n",
|
| 127 |
-
"from time import time\n",
|
| 128 |
-
"import cma\n",
|
| 129 |
-
"from cma.sigma_adaptation import CMAAdaptSigmaCSA, CMAAdaptSigmaTPA\n",
|
| 130 |
-
"import warnings\n",
|
| 131 |
-
"warnings.simplefilter(\"ignore\", cma.evolution_strategy.InjectionWarning)\n",
|
| 132 |
-
"import torchvision.transforms as transforms\n",
|
| 133 |
-
"import inference.utils as inference_utils\n",
|
| 134 |
-
"import data_utils.utils as data_utils\n",
|
| 135 |
-
"from BigGAN_PyTorch.BigGAN import Generator as generator\n",
|
| 136 |
-
"import sklearn.metrics\n",
|
| 137 |
-
"\n",
|
| 138 |
-
"def replace_to_inplace_relu(model): #saves memory; from https://github.com/minyoungg/pix2latent/blob/master/pix2latent/model/biggan.py\n",
|
| 139 |
-
" for child_name, child in model.named_children():\n",
|
| 140 |
-
" if isinstance(child, nn.ReLU):\n",
|
| 141 |
-
" setattr(model, child_name, nn.ReLU(inplace=False))\n",
|
| 142 |
-
" else:\n",
|
| 143 |
-
" replace_to_inplace_relu(child)\n",
|
| 144 |
-
" return\n",
|
| 145 |
-
" \n",
|
| 146 |
-
"def save(out,name=None, torch_format=True):\n",
|
| 147 |
-
" if torch_format:\n",
|
| 148 |
-
" with torch.no_grad():\n",
|
| 149 |
-
" out = out.cpu().numpy()\n",
|
| 150 |
-
" img = convert_to_images(out)[0]\n",
|
| 151 |
-
" if name:\n",
|
| 152 |
-
" imageio.imwrite(name, np.asarray(img))\n",
|
| 153 |
-
" return img\n",
|
| 154 |
-
"\n",
|
| 155 |
-
"hist = []\n",
|
| 156 |
-
"def checkin(i, best_ind, total_losses, losses, regs, out, noise=None, emb=None, probs=None):\n",
|
| 157 |
-
" global sample_num, hist\n",
|
| 158 |
-
" name = None\n",
|
| 159 |
-
" if save_every and i%save_every==0:\n",
|
| 160 |
-
" name = '/content/output/frame_%05d.jpg'%sample_num\n",
|
| 161 |
-
" pil_image = save(out, name)\n",
|
| 162 |
-
" vals0 = [sample_num, i, total_losses[best_ind], losses[best_ind], regs[best_ind], np.mean(total_losses), np.mean(losses), np.mean(regs), np.std(total_losses), np.std(losses), np.std(regs)]\n",
|
| 163 |
-
" stats = 'sample=%d iter=%d best: total=%.2f cos=%.2f reg=%.3f avg: total=%.2f cos=%.2f reg=%.3f std: total=%.2f cos=%.2f reg=%.3f'%tuple(vals0)\n",
|
| 164 |
-
" vals1 = []\n",
|
| 165 |
-
" if noise is not None:\n",
|
| 166 |
-
" vals1 = [np.mean(noise), np.std(noise)]\n",
|
| 167 |
-
" stats += ' noise: avg=%.2f std=%.3f'%tuple(vals1)\n",
|
| 168 |
-
" vals2 = []\n",
|
| 169 |
-
" if emb is not None:\n",
|
| 170 |
-
" vals2 = [emb.mean(),emb.std()]\n",
|
| 171 |
-
" stats += ' emb: avg=%.2f std=%.3f'%tuple(vals2)\n",
|
| 172 |
-
" elif probs:\n",
|
| 173 |
-
" best = probs[best_ind]\n",
|
| 174 |
-
" inds = np.argsort(best)[::-1]\n",
|
| 175 |
-
" probs = np.array(probs)\n",
|
| 176 |
-
" vals2 = [ind2name[inds[0]], best[inds[0]], ind2name[inds[1]], best[inds[1]], ind2name[inds[2]], best[inds[2]], np.sum(probs >= 0.5)/pop_size,np.sum(probs >= 0.3)/pop_size,np.sum(probs >= 0.1)/pop_size]\n",
|
| 177 |
-
" stats += ' 1st=%s(%.2f) 2nd=%s(%.2f) 3rd=%s(%.2f) components: >=0.5:%.0f, >=0.3:%.0f, >=0.1:%.0f'%tuple(vals2)\n",
|
| 178 |
-
" hist.append(vals0+vals1+vals2)\n",
|
| 179 |
-
" if show_every and i%show_every==0:\n",
|
| 180 |
-
" clear_output()\n",
|
| 181 |
-
" display(pil_image) \n",
|
| 182 |
-
" print(stats)\n",
|
| 183 |
-
" sample_num += 1\n",
|
| 184 |
-
"\n",
|
| 185 |
-
"def load_icgan(experiment_name, root_ = '/content'):\n",
|
| 186 |
-
" root = os.path.join(root_, experiment_name)\n",
|
| 187 |
-
" config = torch.load(\"%s/%s.pth\" %\n",
|
| 188 |
-
" (root, \"state_dict_best0\"))['config']\n",
|
| 189 |
-
"\n",
|
| 190 |
-
" config[\"weights_root\"] = root_\n",
|
| 191 |
-
" config[\"model_backbone\"] = 'biggan'\n",
|
| 192 |
-
" config[\"experiment_name\"] = experiment_name\n",
|
| 193 |
-
" # TODO: delete this line\n",
|
| 194 |
-
" G, config = inference_utils.load_model_inference(config)\n",
|
| 195 |
-
" G.cuda()\n",
|
| 196 |
-
" G.eval()\n",
|
| 197 |
-
" return G\n",
|
| 198 |
-
"\n",
|
| 199 |
-
"def get_output(noise_vector, input_label, input_features): \n",
|
| 200 |
-
" if stochastic_truncation: #https://arxiv.org/abs/1702.04782\n",
|
| 201 |
-
" with torch.no_grad():\n",
|
| 202 |
-
" trunc_indices = noise_vector.abs() > 2*truncation\n",
|
| 203 |
-
" size = torch.count_nonzero(trunc_indices).cpu().numpy()\n",
|
| 204 |
-
" trunc = truncnorm.rvs(-2*truncation, 2*truncation, size=(1,size)).astype(np.float32)\n",
|
| 205 |
-
" noise_vector.data[trunc_indices] = torch.tensor(trunc, requires_grad=requires_grad, device='cuda')\n",
|
| 206 |
-
" else:\n",
|
| 207 |
-
" noise_vector = noise_vector.clamp(-2*truncation, 2*truncation)\n",
|
| 208 |
-
" if input_label is not None:\n",
|
| 209 |
-
" input_label = torch.LongTensor(input_label)\n",
|
| 210 |
-
" else:\n",
|
| 211 |
-
" input_label = None\n",
|
| 212 |
-
"\n",
|
| 213 |
-
" out = model(noise_vector, input_label.cuda() if input_label is not None else None, input_features.cuda() if input_features is not None else None)\n",
|
| 214 |
-
" \n",
|
| 215 |
-
" if channels==1:\n",
|
| 216 |
-
" out = out.mean(dim=1, keepdim=True)\n",
|
| 217 |
-
" out = out.repeat(1,3,1,1)\n",
|
| 218 |
-
" return out\n",
|
| 219 |
-
"\n",
|
| 220 |
-
"def normality_loss(vec): #https://arxiv.org/abs/1903.00925\n",
|
| 221 |
-
" mu2 = vec.mean().square()\n",
|
| 222 |
-
" sigma2 = vec.var()\n",
|
| 223 |
-
" return mu2+sigma2-torch.log(sigma2)-1\n",
|
| 224 |
-
" \n",
|
| 225 |
-
"\n",
|
| 226 |
-
"def load_generative_model(gen_model, last_gen_model, experiment_name, model):\n",
|
| 227 |
-
" # Load generative model\n",
|
| 228 |
-
" if gen_model != last_gen_model:\n",
|
| 229 |
-
" model = load_icgan(experiment_name, root_ = '/content')\n",
|
| 230 |
-
" last_gen_model = gen_model\n",
|
| 231 |
-
" return model, last_gen_model\n",
|
| 232 |
-
"\n",
|
| 233 |
-
"def load_feature_extractor(gen_model, last_feature_extractor, feature_extractor):\n",
|
| 234 |
-
" # Load feature extractor to obtain instance features\n",
|
| 235 |
-
" feat_ext_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
|
| 236 |
-
" if last_feature_extractor != feat_ext_name:\n",
|
| 237 |
-
" if feat_ext_name == 'classification':\n",
|
| 238 |
-
" feat_ext_path = ''\n",
|
| 239 |
-
" else:\n",
|
| 240 |
-
" !curl -L -o /content/swav_pretrained.pth.tar -C - 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar' \n",
|
| 241 |
-
" feat_ext_path = '/content/swav_pretrained.pth.tar'\n",
|
| 242 |
-
" last_feature_extractor = feat_ext_name\n",
|
| 243 |
-
" feature_extractor = data_utils.load_pretrained_feature_extractor(feat_ext_path, feature_extractor = feat_ext_name)\n",
|
| 244 |
-
" feature_extractor.eval()\n",
|
| 245 |
-
" return feature_extractor, last_feature_extractor\n",
|
| 246 |
-
"\n",
|
| 247 |
-
"norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
|
| 248 |
-
"norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
|
| 249 |
-
"\n",
|
| 250 |
-
"def preprocess_input_image(input_image_path, size): \n",
|
| 251 |
-
" pil_image = Image_PIL.open(input_image_path).convert('RGB')\n",
|
| 252 |
-
" transform_list = transforms.Compose([data_utils.CenterCropLongEdge(), transforms.Resize((size,size)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)])\n",
|
| 253 |
-
" tensor_image = transform_list(pil_image)\n",
|
| 254 |
-
" tensor_image = torch.nn.functional.interpolate(tensor_image.unsqueeze(0), 224, mode=\"bicubic\", align_corners=True)\n",
|
| 255 |
-
" return tensor_image\n",
|
| 256 |
-
"\n",
|
| 257 |
-
"def preprocess_generated_image(image): \n",
|
| 258 |
-
" transform_list = transforms.Normalize(norm_mean, norm_std)\n",
|
| 259 |
-
" image = transform_list(image*0.5 + 0.5)\n",
|
| 260 |
-
" image = torch.nn.functional.interpolate(image, 224, mode=\"bicubic\", align_corners=True)\n",
|
| 261 |
-
" return image\n",
|
| 262 |
-
"\n",
|
| 263 |
-
"last_gen_model = None\n",
|
| 264 |
-
"last_feature_extractor = None\n",
|
| 265 |
-
"model = None\n",
|
| 266 |
-
"feature_extractor = None"
|
| 267 |
-
]
|
| 268 |
-
},
|
| 269 |
-
{
|
| 270 |
-
"cell_type": "code",
|
| 271 |
-
"execution_count": null,
|
| 272 |
-
"id": "17278e04",
|
| 273 |
-
"metadata": {},
|
| 274 |
-
"outputs": [],
|
| 275 |
-
"source": [
|
| 276 |
-
"#@title Generate images with IC-GAN!\n",
|
| 277 |
-
"#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
|
| 278 |
-
"#@markdown 1. Select which instance to condition on, using one of the following options:\n",
|
| 279 |
-
"#@markdown 1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
|
| 280 |
-
"#@markdown 1. For **input_feature_index**, write an integer from 0 to 999. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
|
| 281 |
-
"#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\"), write an integer from 0 to 999. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for the correspondence between class names and indices.\n",
|
| 282 |
-
"#@markdown 1. **num_samples_ranked** (default=16) indicates the number of generated images to output in a mosaic. These generated images are the ones that scored a higher cosine similarity with the conditioning instance, out of **num_samples_total** (default=160) generated samples. Increasing \"num_samples_total\" will result in higher run times, but more generated images to choose the top \"num_samples_ranked\" from, and therefore higher chance of better image quality. Reducing \"num_samples_total\" too much could result in generated images with poorer visual quality. A ratio of 10:1 (num_samples_total:num_samples_ranked) is recommended.\n",
|
| 283 |
-
"#@markdown 1. Vary **truncation** (default=0.7) from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Truncation values between 0.7 and 0.9 seem to empirically work well.\n",
|
| 284 |
-
"#@markdown 1. **seed**=0 means no seed.\n",
|
| 285 |
-
"\n",
|
| 286 |
-
"gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
|
| 287 |
-
"if gen_model == 'icgan': \n",
|
| 288 |
-
" experiment_name = 'icgan_biggan_imagenet_res256'\n",
|
| 289 |
-
"else:\n",
|
| 290 |
-
" experiment_name = 'cc_icgan_biggan_imagenet_res256'\n",
|
| 291 |
-
"#last_gen_model = experiment_name\n",
|
| 292 |
-
"size = '256'\n",
|
| 293 |
-
"input_image_instance = \"\"#@param {type:\"string\"}\n",
|
| 294 |
-
"input_feature_index = 3#@param {type:'integer'}\n",
|
| 295 |
-
"class_index = 538#@param {type:'integer'}\n",
|
| 296 |
-
"num_samples_ranked = 16#@param {type:'integer'}\n",
|
| 297 |
-
"num_samples_total = 160#@param {type:'integer'}\n",
|
| 298 |
-
"truncation = 0.7#@param {type:'number'}\n",
|
| 299 |
-
"stochastic_truncation = False #@param {type:'boolean'}\n",
|
| 300 |
-
"download_file = True #@param {type:'boolean'}\n",
|
| 301 |
-
"seed = 50#@param {type:'number'}\n",
|
| 302 |
-
"if seed == 0:\n",
|
| 303 |
-
" seed = None\n",
|
| 304 |
-
"noise_size = 128\n",
|
| 305 |
-
"class_size = 1000\n",
|
| 306 |
-
"channels = 3\n",
|
| 307 |
-
"batch_size = 4\n",
|
| 308 |
-
"if gen_model == 'icgan':\n",
|
| 309 |
-
" class_index = None\n",
|
| 310 |
-
"if 'biggan' in gen_model:\n",
|
| 311 |
-
" input_feature_index = None\n",
|
| 312 |
-
" input_image_instance = None\n",
|
| 313 |
-
"\n",
|
| 314 |
-
"assert(num_samples_ranked <=num_samples_total)\n",
|
| 315 |
-
"import numpy as np\n",
|
| 316 |
-
"state = None if not seed else np.random.RandomState(seed)\n",
|
| 317 |
-
"np.random.seed(seed)\n",
|
| 318 |
-
"\n",
|
| 319 |
-
"feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
|
| 320 |
-
"\n",
|
| 321 |
-
"# Load feature extractor (outlier filtering and optionally input image feature extraction)\n",
|
| 322 |
-
"feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
|
| 323 |
-
"# Load features \n",
|
| 324 |
-
"if input_image_instance not in ['None', \"\"]:\n",
|
| 325 |
-
" print('Obtaining instance features from input image!')\n",
|
| 326 |
-
" input_feature_index = None\n",
|
| 327 |
-
" input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
|
| 328 |
-
" print('Displaying instance conditioning:')\n",
|
| 329 |
-
" display(convert_to_images(((input_image_tensor*norm_std + norm_mean)-0.5) / 0.5)[0])\n",
|
| 330 |
-
" with torch.no_grad():\n",
|
| 331 |
-
" input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
|
| 332 |
-
" input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
|
| 333 |
-
"elif input_feature_index is not None:\n",
|
| 334 |
-
" print('Selecting an instance from pre-extracted vectors!')\n",
|
| 335 |
-
" input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
|
| 336 |
-
"else:\n",
|
| 337 |
-
" input_features = None\n",
|
| 338 |
-
"\n",
|
| 339 |
-
"# Load generative model\n",
|
| 340 |
-
"model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
|
| 341 |
-
"# Prepare other variables\n",
|
| 342 |
-
"name_file = '%s_class_index%s_instance_index%s'%(gen_model, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
|
| 343 |
-
"\n",
|
| 344 |
-
"!rm -rf /content/output\n",
|
| 345 |
-
"!mkdir -p /content/output\n",
|
| 346 |
-
"\n",
|
| 347 |
-
"replace_to_inplace_relu(model)\n",
|
| 348 |
-
"ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
|
| 349 |
-
"\n",
|
| 350 |
-
"from google.colab import files, output\n",
|
| 351 |
-
"\n",
|
| 352 |
-
"eps = 1e-8\n",
|
| 353 |
-
"\n",
|
| 354 |
-
"# Create noise, instance and class vector\n",
|
| 355 |
-
"noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(num_samples_total, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
|
| 356 |
-
"noise_vector = torch.tensor(noise_vector, requires_grad=False, device='cuda')\n",
|
| 357 |
-
"if input_features is not None:\n",
|
| 358 |
-
" instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda').repeat(num_samples_total, 1)\n",
|
| 359 |
-
"else: \n",
|
| 360 |
-
" instance_vector = None\n",
|
| 361 |
-
"if class_index is not None:\n",
|
| 362 |
-
" print('Conditioning on class: ', ind2name[class_index])\n",
|
| 363 |
-
" input_label = torch.LongTensor([class_index]*num_samples_total)\n",
|
| 364 |
-
"else:\n",
|
| 365 |
-
" input_label = None\n",
|
| 366 |
-
"if input_feature_index is not None:\n",
|
| 367 |
-
" print('Conditioning on instance with index: ', input_feature_index)\n",
|
| 368 |
-
"\n",
|
| 369 |
-
"size = int(size)\n",
|
| 370 |
-
"all_outs, all_dists = [], []\n",
|
| 371 |
-
"for i_bs in range(num_samples_total//batch_size+1):\n",
|
| 372 |
-
" start = i_bs*batch_size\n",
|
| 373 |
-
" end = min(start+batch_size, num_samples_total)\n",
|
| 374 |
-
" if start == end:\n",
|
| 375 |
-
" break\n",
|
| 376 |
-
" out = get_output(noise_vector[start:end], input_label[start:end] if input_label is not None else None, instance_vector[start:end] if instance_vector is not None else None)\n",
|
| 377 |
-
"\n",
|
| 378 |
-
" if instance_vector is not None:\n",
|
| 379 |
-
" # Get features from generated images + feature extractor\n",
|
| 380 |
-
" out_ = preprocess_generated_image(out)\n",
|
| 381 |
-
" with torch.no_grad():\n",
|
| 382 |
-
" out_features, _ = feature_extractor(out_.cuda())\n",
|
| 383 |
-
" out_features/=torch.linalg.norm(out_features,dim=-1, keepdims=True)\n",
|
| 384 |
-
" dists = sklearn.metrics.pairwise_distances(\n",
|
| 385 |
-
" out_features.cpu(), instance_vector[start:end].cpu(), metric=\"euclidean\", n_jobs=-1)\n",
|
| 386 |
-
" all_dists.append(np.diagonal(dists))\n",
|
| 387 |
-
" all_outs.append(out.detach().cpu())\n",
|
| 388 |
-
" del (out)\n",
|
| 389 |
-
"all_outs = torch.cat(all_outs)\n",
|
| 390 |
-
"all_dists = np.concatenate(all_dists)\n",
|
| 391 |
-
"\n",
|
| 392 |
-
"# Order samples by distance to conditioning feature vector and select only num_samples_ranked images\n",
|
| 393 |
-
"selected_idxs =np.argsort(all_dists)[:num_samples_ranked]\n",
|
| 394 |
-
"#print('All distances re-ordered ', np.sort(all_dists))\n",
|
| 395 |
-
"# Create figure \n",
|
| 396 |
-
"row_i, col_i, i_im = 0, 0, 0\n",
|
| 397 |
-
"all_images_mosaic = np.zeros((3,size*(int(np.sqrt(num_samples_ranked))), size*(int(np.sqrt(num_samples_ranked)))))\n",
|
| 398 |
-
"for j in selected_idxs:\n",
|
| 399 |
-
" all_images_mosaic[:,row_i*size:row_i*size+size, col_i*size:col_i*size+size] = all_outs[j]\n",
|
| 400 |
-
" if row_i == int(np.sqrt(num_samples_ranked))-1:\n",
|
| 401 |
-
" row_i = 0\n",
|
| 402 |
-
" if col_i == int(np.sqrt(num_samples_ranked))-1:\n",
|
| 403 |
-
" col_i = 0\n",
|
| 404 |
-
" else:\n",
|
| 405 |
-
" col_i +=1\n",
|
| 406 |
-
" else:\n",
|
| 407 |
-
" row_i+=1\n",
|
| 408 |
-
" i_im +=1\n",
|
| 409 |
-
"\n",
|
| 410 |
-
"name = '/content/%s_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
|
| 411 |
-
"pil_image = save(all_images_mosaic[np.newaxis,...],name, torch_format=False) \n",
|
| 412 |
-
"print('Displaying generated images')\n",
|
| 413 |
-
"display(pil_image)\n",
|
| 414 |
-
"\n",
|
| 415 |
-
"if download_file:\n",
|
| 416 |
-
" files.download(name)\n",
|
| 417 |
-
"\n"
|
| 418 |
-
]
|
| 419 |
-
},
|
| 420 |
-
{
|
| 421 |
-
"cell_type": "code",
|
| 422 |
-
"execution_count": null,
|
| 423 |
-
"id": "da5ee254",
|
| 424 |
-
"metadata": {},
|
| 425 |
-
"outputs": [],
|
| 426 |
-
"source": [
|
| 427 |
-
"#@title Generate images with IC-GAN + CLIP!\n",
|
| 428 |
-
"#@markdown 1. For **prompt**, OpenAI suggests using the template \"A photo of a X.\" or \"A photo of a X, a type of Y.\" [[paper]](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language_Supervision.pdf)\n",
|
| 429 |
-
"#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
|
| 430 |
-
"#@markdown 1. Select which instance to condition on, using one of the following options:\n",
|
| 431 |
-
"#@markdown 1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
|
| 432 |
-
"#@markdown 1. For **input_feature_index**, write an integer from 0 to 999. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
|
| 433 |
-
"#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\"), write an integer from 0 to 999. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for the correspondence between class names and indices.\n",
|
| 434 |
-
"#@markdown 1. Vary **truncation** from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Truncation values between 0.7 and 0.9 seem to empirically work well.\n",
|
| 435 |
-
"#@markdown 4. **seed**=0 means no seed.\n",
|
| 436 |
-
"prompt = 'A dragon' #@param {type:'string'}\n",
|
| 437 |
-
"gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
|
| 438 |
-
"if gen_model == 'icgan': \n",
|
| 439 |
-
" experiment_name = 'icgan_biggan_imagenet_res256_nofeataug'\n",
|
| 440 |
-
"else:\n",
|
| 441 |
-
" experiment_name = 'cc_icgan_biggan_imagenet_res256_nofeataug'\n",
|
| 442 |
-
"#last_gen_model = experiment_name\n",
|
| 443 |
-
"size = '256'\n",
|
| 444 |
-
"input_image_instance = \"\"#@param {type:\"string\"}\n",
|
| 445 |
-
"\n",
|
| 446 |
-
"input_feature_index = 500#@param {type:'integer'}\n",
|
| 447 |
-
"class_index = 627 #@param {type:'integer'} (only with cc_icgan)\n",
|
| 448 |
-
"download_image = False #@param {type:'boolean'}\n",
|
| 449 |
-
"download_video = False #@param {type:'boolean'}\n",
|
| 450 |
-
"truncation = 0.7 #@param {type:'number'}\n",
|
| 451 |
-
"stochastic_truncation = False #@param {type:'boolean'}\n",
|
| 452 |
-
"optimizer = 'CMA-ES' #@param ['SGD','Adam','CMA-ES','CMA-ES + SGD interleaved','CMA-ES + Adam interleaved','CMA-ES + terminal SGD','CMA-ES + terminal Adam']\n",
|
| 453 |
-
"pop_size = 50 #@param {type:'integer'}\n",
|
| 454 |
-
"clip_model = 'ViT-B/32' #@param ['ViT-B/32','RN50','RN101','RN50x4']\n",
|
| 455 |
-
"augmentations = 64#@param {type:'integer'}\n",
|
| 456 |
-
"learning_rate = 0.1#@param {type:'number'}\n",
|
| 457 |
-
"noise_normality_loss = 0#@param {type:'number'}\n",
|
| 458 |
-
"minimum_entropy_loss = 0.0001 #@param {type:'number'}\n",
|
| 459 |
-
"total_variation_loss = 0.1 #@param {type:'number'}\n",
|
| 460 |
-
"iterations = 100#@param {type:'integer'}\n",
|
| 461 |
-
"terminal_iterations = 100#@param {type:'integer'}\n",
|
| 462 |
-
"show_every = 1 #@param {type:'integer'}\n",
|
| 463 |
-
"save_every = 1 #@param {type:'integer'}\n",
|
| 464 |
-
"fps = 2#@param {type:'number'}\n",
|
| 465 |
-
"freeze_secs = 0 #@param {type:'number'}\n",
|
| 466 |
-
"seed = 10#@param {type:'number'}\n",
|
| 467 |
-
"if seed == 0:\n",
|
| 468 |
-
" seed = None\n",
|
| 469 |
-
"\n",
|
| 470 |
-
"softmax_temp = 1\n",
|
| 471 |
-
"emb_factor = 0.067 #calculated empirically \n",
|
| 472 |
-
"loss_factor = 100\n",
|
| 473 |
-
"sigma0 = 0.5 #http://cma.gforge.inria.fr/cmaes_sourcecode_page.html#practical\n",
|
| 474 |
-
"cma_adapt = True\n",
|
| 475 |
-
"cma_diag = False\n",
|
| 476 |
-
"cma_active = True\n",
|
| 477 |
-
"cma_elitist = False\n",
|
| 478 |
-
"noise_size = 128\n",
|
| 479 |
-
"class_size = 1000\n",
|
| 480 |
-
"channels = 3\n",
|
| 481 |
-
"if gen_model == 'icgan':\n",
|
| 482 |
-
" class_index = None\n",
|
| 483 |
-
"\n",
|
| 484 |
-
"import numpy as np\n",
|
| 485 |
-
"state = None if not seed else np.random.RandomState(seed)\n",
|
| 486 |
-
"np.random.seed(seed)\n",
|
| 487 |
-
"# Load features \n",
|
| 488 |
-
"if input_image_instance not in ['None',\"\"]:\n",
|
| 489 |
-
" print('Obtaining instance features from input image!')\n",
|
| 490 |
-
" input_feature_index = None\n",
|
| 491 |
-
" feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
|
| 492 |
-
" input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
|
| 493 |
-
" input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
|
| 494 |
-
" input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
|
| 495 |
-
"elif input_feature_index is not None:\n",
|
| 496 |
-
" print('Selecting an instance from pre-extracted vectors!')\n",
|
| 497 |
-
" feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
|
| 498 |
-
" input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
|
| 499 |
-
"else:\n",
|
| 500 |
-
" input_features = None\n",
|
| 501 |
-
"\n",
|
| 502 |
-
"\n",
|
| 503 |
-
"# Load generative model\n",
|
| 504 |
-
"model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
|
| 505 |
-
"\n",
|
| 506 |
-
"# Load CLIP model\n",
|
| 507 |
-
"if clip_model != last_clip_model:\n",
|
| 508 |
-
" perceptor, preprocess = clip.load(clip_model)\n",
|
| 509 |
-
" last_clip_model = clip_model\n",
|
| 510 |
-
"clip_res = perceptor.visual.input_resolution\n",
|
| 511 |
-
"sideX = sideY = int(size)\n",
|
| 512 |
-
"if sideX<=clip_res and sideY<=clip_res:\n",
|
| 513 |
-
" augmentations = 1\n",
|
| 514 |
-
"if 'CMA' not in optimizer:\n",
|
| 515 |
-
" pop_size = 1\n",
|
| 516 |
-
"\n",
|
| 517 |
-
"# Prepare other variables\n",
|
| 518 |
-
"name_file = '%s_%s_class_index%s_instance_index%s'%(gen_model, prompt, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
|
| 519 |
-
"requires_grad = ('SGD' in optimizer or 'Adam' in optimizer) and ('terminal' not in optimizer or terminal_iterations>0)\n",
|
| 520 |
-
"total_iterations = iterations + terminal_iterations*('terminal' in optimizer)\n",
|
| 521 |
-
"\n",
|
| 522 |
-
"!rm -rf /content/output\n",
|
| 523 |
-
"!mkdir -p /content/output\n",
|
| 524 |
-
"\n",
|
| 525 |
-
"replace_to_inplace_relu(model)\n",
|
| 526 |
-
"replace_to_inplace_relu(perceptor)\n",
|
| 527 |
-
"ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
|
| 528 |
-
"eps = 1e-8\n",
|
| 529 |
-
"\n",
|
| 530 |
-
"# Create noise and instance vector\n",
|
| 531 |
-
"noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(pop_size, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
|
| 532 |
-
"noise_vector = torch.tensor(noise_vector, requires_grad=requires_grad, device='cuda')\n",
|
| 533 |
-
"if input_features is not None:\n",
|
| 534 |
-
" instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda')\n",
|
| 535 |
-
"else: \n",
|
| 536 |
-
" instance_vector = None\n",
|
| 537 |
-
"if class_index is not None:\n",
|
| 538 |
-
" print('Conditioning on class: ', ind2name[class_index])\n",
|
| 539 |
-
"if input_feature_index is not None:\n",
|
| 540 |
-
" print('Conditioning on instance with index: ', input_feature_index)\n",
|
| 541 |
-
"\n",
|
| 542 |
-
"# Prepare optimizer\n",
|
| 543 |
-
"if requires_grad:\n",
|
| 544 |
-
" params = [noise_vector]\n",
|
| 545 |
-
" if 'SGD' in optimizer:\n",
|
| 546 |
-
" optim = torch.optim.SGD(params, lr=learning_rate, momentum=0.9) \n",
|
| 547 |
-
" else:\n",
|
| 548 |
-
" optim = torch.optim.Adam(params, lr=learning_rate)\n",
|
| 549 |
-
"\n",
|
| 550 |
-
"def ascend_txt(i, grad_step=False, show_save=False):\n",
|
| 551 |
-
" global global_best_loss, global_best_iteration, global_best_noise_vector, global_best_class_vector\n",
|
| 552 |
-
" regs = []\n",
|
| 553 |
-
" losses = []\n",
|
| 554 |
-
" total_losses = []\n",
|
| 555 |
-
" best_loss = np.inf\n",
|
| 556 |
-
" global_reg = torch.tensor(0, device='cuda', dtype=torch.float32, requires_grad=grad_step)\n",
|
| 557 |
-
" if noise_normality_loss:\n",
|
| 558 |
-
" global_reg = global_reg+noise_normality_loss*normality_loss(noise_vector)\n",
|
| 559 |
-
" global_reg = loss_factor*global_reg \n",
|
| 560 |
-
" if grad_step:\n",
|
| 561 |
-
" global_reg.backward()\n",
|
| 562 |
-
" for j in range(pop_size):\n",
|
| 563 |
-
" p_s = []\n",
|
| 564 |
-
" out = get_output(noise_vector[j:j+1], [class_index] if class_index is not None else None, instance_vector)\n",
|
| 565 |
-
" for aug in range(augmentations):\n",
|
| 566 |
-
" if sideX<=clip_res and sideY<=clip_res or augmentations==1:\n",
|
| 567 |
-
" apper = out \n",
|
| 568 |
-
" else:\n",
|
| 569 |
-
" size = torch.randint(int(.7*sideX), int(.98*sideX), ())\n",
|
| 570 |
-
" offsetx = torch.randint(0, sideX - size, ())\n",
|
| 571 |
-
" offsety = torch.randint(0, sideX - size, ())\n",
|
| 572 |
-
" apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]\n",
|
| 573 |
-
" apper = (apper+1)/2\n",
|
| 574 |
-
" apper = nn.functional.interpolate(apper, clip_res, mode='bilinear')\n",
|
| 575 |
-
" #apper = apper.clamp(0,1)\n",
|
| 576 |
-
" p_s.append(apper)\n",
|
| 577 |
-
" into = nom(torch.cat(p_s, 0))\n",
|
| 578 |
-
" predict_clip = perceptor.encode_image(into)\n",
|
| 579 |
-
" loss = loss_factor*(1-torch.cosine_similarity(predict_clip, target_clip).mean())\n",
|
| 580 |
-
" total_loss = loss\n",
|
| 581 |
-
" regs.append(global_reg.item())\n",
|
| 582 |
-
"\n",
|
| 583 |
-
" with torch.no_grad():\n",
|
| 584 |
-
" losses.append(loss.item())\n",
|
| 585 |
-
" total_losses.append(total_loss.item()+global_reg.item())\n",
|
| 586 |
-
" if total_losses[-1]<best_loss:\n",
|
| 587 |
-
" best_loss = total_losses[-1]\n",
|
| 588 |
-
" best_ind = j\n",
|
| 589 |
-
" best_out = out\n",
|
| 590 |
-
" if best_loss < global_best_loss:\n",
|
| 591 |
-
" global_best_loss = best_loss\n",
|
| 592 |
-
" global_best_iteration = i\n",
|
| 593 |
-
" with torch.no_grad():\n",
|
| 594 |
-
" global_best_noise_vector = noise_vector[best_ind]\n",
|
| 595 |
-
" if grad_step: \n",
|
| 596 |
-
" total_loss.backward()\n",
|
| 597 |
-
"\n",
|
| 598 |
-
" if grad_step:\n",
|
| 599 |
-
" optim.step()\n",
|
| 600 |
-
" optim.zero_grad()\n",
|
| 601 |
-
"\n",
|
| 602 |
-
" if show_save and (save_every and i % save_every == 0 or show_every and i % show_every == 0):\n",
|
| 603 |
-
" noise = None\n",
|
| 604 |
-
" emb = None\n",
|
| 605 |
-
" with torch.no_grad():\n",
|
| 606 |
-
" noise = noise_vector.cpu().numpy()\n",
|
| 607 |
-
" checkin(i, best_ind, total_losses, losses, regs, best_out, noise, emb) \n",
|
| 608 |
-
" return total_losses, best_ind\n",
|
| 609 |
-
"\n",
|
| 610 |
-
"# Obtain target CLIP representation\n",
|
| 611 |
-
"tx = clip.tokenize(prompt)\n",
|
| 612 |
-
"with torch.no_grad():\n",
|
| 613 |
-
" target_clip = perceptor.encode_text(tx.cuda())\n",
|
| 614 |
-
"\n",
|
| 615 |
-
"\n",
|
| 616 |
-
"global_best_loss = np.inf\n",
|
| 617 |
-
"global_best_iteration = 0\n",
|
| 618 |
-
"global_best_noise_vector = None\n",
|
| 619 |
-
"\n",
|
| 620 |
-
"nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
|
| 621 |
-
"if 'CMA' in optimizer:\n",
|
| 622 |
-
" initial_vector = np.zeros(noise_size)\n",
|
| 623 |
-
" bounds = None\n",
|
| 624 |
-
" cma_opts = {'popsize': pop_size, 'seed': np.nan, 'AdaptSigma': cma_adapt, 'CMA_diagonal': cma_diag, 'CMA_active': cma_active, 'CMA_elitist':cma_elitist, 'bounds':bounds}\n",
|
| 625 |
-
" cmaes = cma.CMAEvolutionStrategy(initial_vector, sigma0, inopts=cma_opts)\n",
|
| 626 |
-
"\n",
|
| 627 |
-
"sample_num = 0\n",
|
| 628 |
-
"machine = !nvidia-smi -L\n",
|
| 629 |
-
"start = time()\n",
|
| 630 |
-
"\n",
|
| 631 |
-
"# Start noise vector optimization\n",
|
| 632 |
-
"for i in range(total_iterations): \n",
|
| 633 |
-
" if 'CMA' in optimizer and i<iterations:\n",
|
| 634 |
-
" with torch.no_grad():\n",
|
| 635 |
-
" cma_results = torch.tensor(cmaes.ask(), dtype=torch.float32).cuda()\n",
|
| 636 |
-
" noise_vector.data = cma_results \n",
|
| 637 |
-
" if requires_grad and ('terminal' not in optimizer or i>=iterations):\n",
|
| 638 |
-
" losses, best_ind = ascend_txt(i, grad_step=True, show_save='CMA' not in optimizer or i>=iterations)\n",
|
| 639 |
-
" assert noise_vector.requires_grad and noise_vector.is_leaf and (not optimize_class or class_vector.requires_grad and class_vector.is_leaf), (noise_vector.requires_grad, noise_vector.is_leaf, class_vector.requires_grad, class_vector.is_leaf)\n",
|
| 640 |
-
" if 'CMA' in optimizer and i<iterations:\n",
|
| 641 |
-
" with torch.no_grad():\n",
|
| 642 |
-
" losses, best_ind = ascend_txt(i, show_save=True)\n",
|
| 643 |
-
" if i<iterations-1:\n",
|
| 644 |
-
" vectors = noise_vector\n",
|
| 645 |
-
" cmaes.tell(vectors.cpu().numpy(), losses)\n",
|
| 646 |
-
" elif 'terminal' in optimizer and terminal_iterations:\n",
|
| 647 |
-
" pop_size = 1\n",
|
| 648 |
-
" noise_vector[0] = global_best_noise_vector\n",
|
| 649 |
-
" if save_every and i % save_every == 0 or show_every and i % show_every == 0:\n",
|
| 650 |
-
" print('took: %d secs (%.2f sec/iter) on %s. CUDA memory: %.1f GB'%(time()-start,(time()-start)/(i+1), machine[0], torch.cuda.max_memory_allocated()/1024**3))\n",
|
| 651 |
-
"\n",
|
| 652 |
-
"# Obtain generated image with lowest loss.\n",
|
| 653 |
-
"out = get_output(global_best_noise_vector.unsqueeze(0), [class_index] if class_index is not None else None, instance_vector)\n",
|
| 654 |
-
"name = '/content/%s_best_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
|
| 655 |
-
"pil_image = save(out,name) \n",
|
| 656 |
-
"display(pil_image) \n",
|
| 657 |
-
"print('best_loss=%.2f best_iter=%d'%(global_best_loss,global_best_iteration))\n",
|
| 658 |
-
"\n",
|
| 659 |
-
"if download_image:\n",
|
| 660 |
-
" from google.colab import files, output\n",
|
| 661 |
-
" files.download(name)\n",
|
| 662 |
-
"\n",
|
| 663 |
-
"if download_video:\n",
|
| 664 |
-
" out = '\"/content/%s_seed%i.mp4\"'%(name_file, seed if seed is not None else -1)\n",
|
| 665 |
-
" file_name = '/content/%s_seed%i.mp4'%(name_file, seed if seed is not None else -1)\n",
|
| 666 |
-
"\n",
|
| 667 |
-
" with open('/content/list.txt','w') as f:\n",
|
| 668 |
-
" for i in range(sample_num):\n",
|
| 669 |
-
" f.write('file /content/output/frame_%05d.jpg\\n'%i)\n",
|
| 670 |
-
" for j in range(int(freeze_secs*fps)):\n",
|
| 671 |
-
" f.write('file /content/output/frame_%05d.jpg\\n'%i)\n",
|
| 672 |
-
" !ffmpeg -r $fps -f concat -safe 0 -i /content/list.txt -c:v libx264 -pix_fmt yuv420p -profile:v baseline -movflags +faststart -r $fps $out -y\n",
|
| 673 |
-
" with open(file_name, 'rb') as f:\n",
|
| 674 |
-
" data_url = \"data:video/mp4;base64,\" + b64encode(f.read()).decode()\n",
|
| 675 |
-
" display(HTML(\"\"\"\n",
|
| 676 |
-
" <video controls autoplay loop>\n",
|
| 677 |
-
" <source src=\"%s\" type=\"video/mp4\">\n",
|
| 678 |
-
" </video>\"\"\" % data_url))\n",
|
| 679 |
-
"\n",
|
| 680 |
-
" from google.colab import files, output\n",
|
| 681 |
-
" output.eval_js('new Audio(\"https://freesound.org/data/previews/80/80921_1022651-lq.ogg\").play()')\n",
|
| 682 |
-
" files.download(file_name)"
|
| 683 |
-
]
|
| 684 |
-
}
|
| 685 |
-
],
|
| 686 |
-
"metadata": {
|
| 687 |
-
"kernelspec": {
|
| 688 |
-
"display_name": "Python 3",
|
| 689 |
-
"language": "python",
|
| 690 |
-
"name": "python3"
|
| 691 |
-
},
|
| 692 |
-
"language_info": {
|
| 693 |
-
"codemirror_mode": {
|
| 694 |
-
"name": "ipython",
|
| 695 |
-
"version": 3
|
| 696 |
-
},
|
| 697 |
-
"file_extension": ".py",
|
| 698 |
-
"mimetype": "text/x-python",
|
| 699 |
-
"name": "python",
|
| 700 |
-
"nbconvert_exporter": "python",
|
| 701 |
-
"pygments_lexer": "ipython3",
|
| 702 |
-
"version": "3.8.10"
|
| 703 |
-
}
|
| 704 |
-
},
|
| 705 |
-
"nbformat": 4,
|
| 706 |
-
"nbformat_minor": 5
|
| 707 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|