qpqpqpqpqpqp commited on
Commit
bea59ef
·
verified ·
1 Parent(s): b858b00

Delete Junk

Browse files
inference/.ipynb_checkpoints/icgan_colab-checkpoint.ipynb DELETED
@@ -1,707 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "bb75a2b9",
6
- "metadata": {},
7
- "source": [
8
- "Copyright (c) Facebook, Inc. and its affiliates.\n",
9
- "All rights reserved.\n",
10
- "\n",
11
- "This source code is licensed under the license found in the\n",
12
- "LICENSE file in the root directory of this source tree."
13
- ]
14
- },
15
- {
16
- "cell_type": "markdown",
17
- "id": "81a8ddb6",
18
- "metadata": {},
19
- "source": [
20
- "# IC-GAN\n",
21
- "\n",
22
- "\n",
23
- "Official Colab notebook from the paper <b>\"Instance-Conditioned GAN\"</b> by Arantxa Casanova, Marlene Careil, Jakob Verbeek, Michał Drożdżal, Adriana Romero-Soriano.\n",
24
- "\n",
25
- "This Colab provides the code to generate images with IC-GAN, with the option of further guiding the generation with captions (CLIP). \n",
26
- "\n",
27
- "Based on the Colab [WanderClip](https://j.mp/wanderclip) by Eyal Gruss [@eyaler](https://twitter.com/eyaler) [eyalgruss.com](https://eyalgruss.com)\n",
28
- "\n",
29
- "Using the work from [our repository](https://github.com/facebookresearch/ic_gan)\n",
30
- "\n",
31
- "https://github.com/openai/CLIP, Copyright (c) 2021 OpenAI\n",
32
- "\n",
33
- "https://github.com/huggingface/pytorch-pretrained-BigGAN, Copyright (c) 2019 Thomas Wolf\n",
34
- "\n",
35
- "\n"
36
- ]
37
- },
38
- {
39
- "cell_type": "code",
40
- "execution_count": null,
41
- "id": "9442671e",
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "#@title Restart after running this cell!\n",
46
- "\n",
47
- "!nvidia-smi -L\n",
48
- "\n",
49
- "import subprocess\n",
50
- "\n",
51
- "CUDA_version = [s for s in subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\").split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
52
- "print(\"CUDA version:\", CUDA_version)\n",
53
- "\n",
54
- "if CUDA_version == \"10.1\":\n",
55
- " torch_version_suffix = \"+cu101\"\n",
56
- "elif CUDA_version == \"10.2\":\n",
57
- " torch_version_suffix = \"+cu102\"\n",
58
- "else:\n",
59
- " torch_version_suffix = \"+cu111\"\n",
60
- "\n",
61
- "!pip install torch==1.8.0{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
62
- ]
63
- },
64
- {
65
- "cell_type": "code",
66
- "execution_count": null,
67
- "id": "b01f51f1",
68
- "metadata": {},
69
- "outputs": [],
70
- "source": [
71
- "#@title Setup\n",
72
- "!git clone https://github.com/facebookresearch/ic_gan.git\n",
73
- "\n",
74
- "%cd /content\n",
75
- "# Uncompress required files\n",
76
- "!wget https://dl.fbaipublicfiles.com/ic_gan/cc_icgan_biggan_imagenet_res256.tar.gz\n",
77
- "!tar -xvf cc_icgan_biggan_imagenet_res256.tar.gz\n",
78
- "!wget https://dl.fbaipublicfiles.com/ic_gan/icgan_biggan_imagenet_res256.tar.gz\n",
79
- "!tar -xvf icgan_biggan_imagenet_res256.tar.gz\n",
80
- "!wget https://dl.fbaipublicfiles.com/ic_gan/stored_instances.tar.gz\n",
81
- "!tar -xvf stored_instances.tar.gz\n",
82
- "\n",
83
- "!pip install pytorch-pretrained-biggan\n",
84
- "\n",
85
- "!git clone --depth 1 https://github.com/openai/CLIP\n",
86
- "!pip install ftfy\n",
87
- "%cd /content/CLIP\n",
88
- "import clip\n",
89
- "last_clip_model = 'ViT-B/32'\n",
90
- "perceptor, preprocess = clip.load(last_clip_model)\n",
91
- "\n",
92
- "import nltk\n",
93
- "nltk.download('wordnet')\n",
94
- "\n",
95
- "!pip install cma\n"
96
- ]
97
- },
98
- {
99
- "cell_type": "code",
100
- "execution_count": null,
101
- "id": "fa6c7629",
102
- "metadata": {},
103
- "outputs": [],
104
- "source": [
105
- "#@title Prepare functions\n",
106
- "from pytorch_pretrained_biggan import BigGAN, convert_to_images, one_hot_from_names, utils\n",
107
- "\n",
108
- "%cd /content/ic_gan/\n",
109
- "import sys\n",
110
- "import os\n",
111
- "sys.path[0] = '/content/ic_gan/inference'\n",
112
- "sys.path.insert(1, os.path.join(sys.path[0], \"..\"))\n",
113
- "import torch \n",
114
- "\n",
115
- "import numpy as np\n",
116
- "import torch\n",
117
- "import torchvision\n",
118
- "import sys\n",
119
- "torch.manual_seed(np.random.randint(sys.maxsize))\n",
120
- "import imageio\n",
121
- "from IPython.display import HTML, Image, clear_output\n",
122
- "from PIL import Image as Image_PIL\n",
123
- "from scipy.stats import truncnorm, dirichlet\n",
124
- "from torch import nn\n",
125
- "from nltk.corpus import wordnet as wn\n",
126
- "from base64 import b64encode\n",
127
- "from time import time\n",
128
- "import cma\n",
129
- "from cma.sigma_adaptation import CMAAdaptSigmaCSA, CMAAdaptSigmaTPA\n",
130
- "import warnings\n",
131
- "warnings.simplefilter(\"ignore\", cma.evolution_strategy.InjectionWarning)\n",
132
- "import torchvision.transforms as transforms\n",
133
- "import inference.utils as inference_utils\n",
134
- "import data_utils.utils as data_utils\n",
135
- "from BigGAN_PyTorch.BigGAN import Generator as generator\n",
136
- "import sklearn.metrics\n",
137
- "\n",
138
- "def replace_to_inplace_relu(model): #saves memory; from https://github.com/minyoungg/pix2latent/blob/master/pix2latent/model/biggan.py\n",
139
- " for child_name, child in model.named_children():\n",
140
- " if isinstance(child, nn.ReLU):\n",
141
- " setattr(model, child_name, nn.ReLU(inplace=False))\n",
142
- " else:\n",
143
- " replace_to_inplace_relu(child)\n",
144
- " return\n",
145
- " \n",
146
- "def save(out,name=None, torch_format=True):\n",
147
- " if torch_format:\n",
148
- " with torch.no_grad():\n",
149
- " out = out.cpu().numpy()\n",
150
- " img = convert_to_images(out)[0]\n",
151
- " if name:\n",
152
- " imageio.imwrite(name, np.asarray(img))\n",
153
- " return img\n",
154
- "\n",
155
- "hist = []\n",
156
- "def checkin(i, best_ind, total_losses, losses, regs, out, noise=None, emb=None, probs=None):\n",
157
- " global sample_num, hist\n",
158
- " name = None\n",
159
- " if save_every and i%save_every==0:\n",
160
- " name = '/content/output/frame_%05d.jpg'%sample_num\n",
161
- " pil_image = save(out, name)\n",
162
- " vals0 = [sample_num, i, total_losses[best_ind], losses[best_ind], regs[best_ind], np.mean(total_losses), np.mean(losses), np.mean(regs), np.std(total_losses), np.std(losses), np.std(regs)]\n",
163
- " stats = 'sample=%d iter=%d best: total=%.2f cos=%.2f reg=%.3f avg: total=%.2f cos=%.2f reg=%.3f std: total=%.2f cos=%.2f reg=%.3f'%tuple(vals0)\n",
164
- " vals1 = []\n",
165
- " if noise is not None:\n",
166
- " vals1 = [np.mean(noise), np.std(noise)]\n",
167
- " stats += ' noise: avg=%.2f std=%.3f'%tuple(vals1)\n",
168
- " vals2 = []\n",
169
- " if emb is not None:\n",
170
- " vals2 = [emb.mean(),emb.std()]\n",
171
- " stats += ' emb: avg=%.2f std=%.3f'%tuple(vals2)\n",
172
- " elif probs:\n",
173
- " best = probs[best_ind]\n",
174
- " inds = np.argsort(best)[::-1]\n",
175
- " probs = np.array(probs)\n",
176
- " vals2 = [ind2name[inds[0]], best[inds[0]], ind2name[inds[1]], best[inds[1]], ind2name[inds[2]], best[inds[2]], np.sum(probs >= 0.5)/pop_size,np.sum(probs >= 0.3)/pop_size,np.sum(probs >= 0.1)/pop_size]\n",
177
- " stats += ' 1st=%s(%.2f) 2nd=%s(%.2f) 3rd=%s(%.2f) components: >=0.5:%.0f, >=0.3:%.0f, >=0.1:%.0f'%tuple(vals2)\n",
178
- " hist.append(vals0+vals1+vals2)\n",
179
- " if show_every and i%show_every==0:\n",
180
- " clear_output()\n",
181
- " display(pil_image) \n",
182
- " print(stats)\n",
183
- " sample_num += 1\n",
184
- "\n",
185
- "def load_icgan(experiment_name, root_ = '/content'):\n",
186
- " root = os.path.join(root_, experiment_name)\n",
187
- " config = torch.load(\"%s/%s.pth\" %\n",
188
- " (root, \"state_dict_best0\"))['config']\n",
189
- "\n",
190
- " config[\"weights_root\"] = root_\n",
191
- " config[\"model_backbone\"] = 'biggan'\n",
192
- " config[\"experiment_name\"] = experiment_name\n",
193
- " # TODO: delete this line\n",
194
- " G, config = inference_utils.load_model_inference(config)\n",
195
- " G.cuda()\n",
196
- " G.eval()\n",
197
- " return G\n",
198
- "\n",
199
- "def get_output(noise_vector, input_label, input_features): \n",
200
- " if stochastic_truncation: #https://arxiv.org/abs/1702.04782\n",
201
- " with torch.no_grad():\n",
202
- " trunc_indices = noise_vector.abs() > 2*truncation\n",
203
- " size = torch.count_nonzero(trunc_indices).cpu().numpy()\n",
204
- " trunc = truncnorm.rvs(-2*truncation, 2*truncation, size=(1,size)).astype(np.float32)\n",
205
- " noise_vector.data[trunc_indices] = torch.tensor(trunc, requires_grad=requires_grad, device='cuda')\n",
206
- " else:\n",
207
- " noise_vector = noise_vector.clamp(-2*truncation, 2*truncation)\n",
208
- " if input_label is not None:\n",
209
- " input_label = torch.LongTensor(input_label)\n",
210
- " else:\n",
211
- " input_label = None\n",
212
- "\n",
213
- " out = model(noise_vector, input_label.cuda() if input_label is not None else None, input_features.cuda() if input_features is not None else None)\n",
214
- " \n",
215
- " if channels==1:\n",
216
- " out = out.mean(dim=1, keepdim=True)\n",
217
- " out = out.repeat(1,3,1,1)\n",
218
- " return out\n",
219
- "\n",
220
- "def normality_loss(vec): #https://arxiv.org/abs/1903.00925\n",
221
- " mu2 = vec.mean().square()\n",
222
- " sigma2 = vec.var()\n",
223
- " return mu2+sigma2-torch.log(sigma2)-1\n",
224
- " \n",
225
- "\n",
226
- "def load_generative_model(gen_model, last_gen_model, experiment_name, model):\n",
227
- " # Load generative model\n",
228
- " if gen_model != last_gen_model:\n",
229
- " model = load_icgan(experiment_name, root_ = '/content')\n",
230
- " last_gen_model = gen_model\n",
231
- " return model, last_gen_model\n",
232
- "\n",
233
- "def load_feature_extractor(gen_model, last_feature_extractor, feature_extractor):\n",
234
- " # Load feature extractor to obtain instance features\n",
235
- " feat_ext_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
236
- " if last_feature_extractor != feat_ext_name:\n",
237
- " if feat_ext_name == 'classification':\n",
238
- " feat_ext_path = ''\n",
239
- " else:\n",
240
- " !curl -L -o /content/swav_pretrained.pth.tar -C - 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar' \n",
241
- " feat_ext_path = '/content/swav_pretrained.pth.tar'\n",
242
- " last_feature_extractor = feat_ext_name\n",
243
- " feature_extractor = data_utils.load_pretrained_feature_extractor(feat_ext_path, feature_extractor = feat_ext_name)\n",
244
- " feature_extractor.eval()\n",
245
- " return feature_extractor, last_feature_extractor\n",
246
- "\n",
247
- "norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
248
- "norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
249
- "\n",
250
- "def preprocess_input_image(input_image_path, size): \n",
251
- " pil_image = Image_PIL.open(input_image_path).convert('RGB')\n",
252
- " transform_list = transforms.Compose([data_utils.CenterCropLongEdge(), transforms.Resize((size,size)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)])\n",
253
- " tensor_image = transform_list(pil_image)\n",
254
- " tensor_image = torch.nn.functional.interpolate(tensor_image.unsqueeze(0), 224, mode=\"bicubic\", align_corners=True)\n",
255
- " return tensor_image\n",
256
- "\n",
257
- "def preprocess_generated_image(image): \n",
258
- " transform_list = transforms.Normalize(norm_mean, norm_std)\n",
259
- " image = transform_list(image*0.5 + 0.5)\n",
260
- " image = torch.nn.functional.interpolate(image, 224, mode=\"bicubic\", align_corners=True)\n",
261
- " return image\n",
262
- "\n",
263
- "last_gen_model = None\n",
264
- "last_feature_extractor = None\n",
265
- "model = None\n",
266
- "feature_extractor = None"
267
- ]
268
- },
269
- {
270
- "cell_type": "code",
271
- "execution_count": null,
272
- "id": "17278e04",
273
- "metadata": {},
274
- "outputs": [],
275
- "source": [
276
- "#@title Generate images with IC-GAN!\n",
277
- "#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
278
- "#@markdown 1. Select which instance to condition on, following one of the following options:\n",
279
- "#@markdown 1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
280
- "#@markdown 1. **input_feature_index** write an integer from 0 to 1000. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
281
- "#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\") write an integer from 0 to 1000. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for a correspondence between class name and indexes.\n",
282
- "#@markdown 1. **num_samples_ranked** (default=16) indicates the number of generated images to output in a mosaic. These generated images are the ones that scored a higher cosine similarity with the conditioning instance, out of **num_samples_total** (default=160) generated samples. Increasing \"num_samples_total\" will result in higher run times, but more generated images to choose the top \"num_samples_ranked\" from, and therefore higher chance of better image quality. Reducing \"num_samples_total\" too much could result in generated images with poorer visual quality. A ratio of 10:1 (num_samples_total:num_samples_ranked) is recommended.\n",
283
- "#@markdown 1. Vary **truncation** (default=0.7) from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Trucation values between 0.7 and 0.9 seem to empirically work well.\n",
284
- "#@markdown 1. **seed**=0 means no seed.\n",
285
- "\n",
286
- "gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
287
- "if gen_model == 'icgan': \n",
288
- " experiment_name = 'icgan_biggan_imagenet_res256'\n",
289
- "else:\n",
290
- " experiment_name = 'cc_icgan_biggan_imagenet_res256'\n",
291
- "#last_gen_model = experiment_name\n",
292
- "size = '256'\n",
293
- "input_image_instance = \"\"#@param {type:\"string\"}\n",
294
- "input_feature_index = 3#@param {type:'integer'}\n",
295
- "class_index = 538#@param {type:'integer'}\n",
296
- "num_samples_ranked = 16#@param {type:'integer'}\n",
297
- "num_samples_total = 160#@param {type:'integer'}\n",
298
- "truncation = 0.7#@param {type:'number'}\n",
299
- "stochastic_truncation = False #@param {type:'boolean'}\n",
300
- "download_file = True #@param {type:'boolean'}\n",
301
- "seed = 50#@param {type:'number'}\n",
302
- "if seed == 0:\n",
303
- " seed = None\n",
304
- "noise_size = 128\n",
305
- "class_size = 1000\n",
306
- "channels = 3\n",
307
- "batch_size = 4\n",
308
- "if gen_model == 'icgan':\n",
309
- " class_index = None\n",
310
- "if 'biggan' in gen_model:\n",
311
- " input_feature_index = None\n",
312
- " input_image_instance = None\n",
313
- "\n",
314
- "assert(num_samples_ranked <=num_samples_total)\n",
315
- "import numpy as np\n",
316
- "state = None if not seed else np.random.RandomState(seed)\n",
317
- "np.random.seed(seed)\n",
318
- "\n",
319
- "feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
320
- "\n",
321
- "# Load feature extractor (outlier filtering and optionally input image feature extraction)\n",
322
- "feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
323
- "# Load features \n",
324
- "if input_image_instance not in ['None', \"\"]:\n",
325
- " print('Obtaining instance features from input image!')\n",
326
- " input_feature_index = None\n",
327
- " input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
328
- " print('Displaying instance conditioning:')\n",
329
- " display(convert_to_images(((input_image_tensor*norm_std + norm_mean)-0.5) / 0.5)[0])\n",
330
- " with torch.no_grad():\n",
331
- " input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
332
- " input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
333
- "elif input_feature_index is not None:\n",
334
- " print('Selecting an instance from pre-extracted vectors!')\n",
335
- " input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
336
- "else:\n",
337
- " input_features = None\n",
338
- "\n",
339
- "# Load generative model\n",
340
- "model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
341
- "# Prepare other variables\n",
342
- "name_file = '%s_class_index%s_instance_index%s'%(gen_model, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
343
- "\n",
344
- "!rm -rf /content/output\n",
345
- "!mkdir -p /content/output\n",
346
- "\n",
347
- "replace_to_inplace_relu(model)\n",
348
- "ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
349
- "\n",
350
- "from google.colab import files, output\n",
351
- "\n",
352
- "eps = 1e-8\n",
353
- "\n",
354
- "# Create noise, instance and class vector\n",
355
- "noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(num_samples_total, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
356
- "noise_vector = torch.tensor(noise_vector, requires_grad=False, device='cuda')\n",
357
- "if input_features is not None:\n",
358
- " instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda').repeat(num_samples_total, 1)\n",
359
- "else: \n",
360
- " instance_vector = None\n",
361
- "if class_index is not None:\n",
362
- " print('Conditioning on class: ', ind2name[class_index])\n",
363
- " input_label = torch.LongTensor([class_index]*num_samples_total)\n",
364
- "else:\n",
365
- " input_label = None\n",
366
- "if input_feature_index is not None:\n",
367
- " print('Conditioning on instance with index: ', input_feature_index)\n",
368
- "\n",
369
- "size = int(size)\n",
370
- "all_outs, all_dists = [], []\n",
371
- "for i_bs in range(num_samples_total//batch_size+1):\n",
372
- " start = i_bs*batch_size\n",
373
- " end = min(start+batch_size, num_samples_total)\n",
374
- " if start == end:\n",
375
- " break\n",
376
- " out = get_output(noise_vector[start:end], input_label[start:end] if input_label is not None else None, instance_vector[start:end] if instance_vector is not None else None)\n",
377
- "\n",
378
- " if instance_vector is not None:\n",
379
- " # Get features from generated images + feature extractor\n",
380
- " out_ = preprocess_generated_image(out)\n",
381
- " with torch.no_grad():\n",
382
- " out_features, _ = feature_extractor(out_.cuda())\n",
383
- " out_features/=torch.linalg.norm(out_features,dim=-1, keepdims=True)\n",
384
- " dists = sklearn.metrics.pairwise_distances(\n",
385
- " out_features.cpu(), instance_vector[start:end].cpu(), metric=\"euclidean\", n_jobs=-1)\n",
386
- " all_dists.append(np.diagonal(dists))\n",
387
- " all_outs.append(out.detach().cpu())\n",
388
- " del (out)\n",
389
- "all_outs = torch.cat(all_outs)\n",
390
- "all_dists = np.concatenate(all_dists)\n",
391
- "\n",
392
- "# Order samples by distance to conditioning feature vector and select only num_samples_ranked images\n",
393
- "selected_idxs =np.argsort(all_dists)[:num_samples_ranked]\n",
394
- "#print('All distances re-ordered ', np.sort(all_dists))\n",
395
- "# Create figure \n",
396
- "row_i, col_i, i_im = 0, 0, 0\n",
397
- "all_images_mosaic = np.zeros((3,size*(int(np.sqrt(num_samples_ranked))), size*(int(np.sqrt(num_samples_ranked)))))\n",
398
- "for j in selected_idxs:\n",
399
- " all_images_mosaic[:,row_i*size:row_i*size+size, col_i*size:col_i*size+size] = all_outs[j]\n",
400
- " if row_i == int(np.sqrt(num_samples_ranked))-1:\n",
401
- " row_i = 0\n",
402
- " if col_i == int(np.sqrt(num_samples_ranked))-1:\n",
403
- " col_i = 0\n",
404
- " else:\n",
405
- " col_i +=1\n",
406
- " else:\n",
407
- " row_i+=1\n",
408
- " i_im +=1\n",
409
- "\n",
410
- "name = '/content/%s_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
411
- "pil_image = save(all_images_mosaic[np.newaxis,...],name, torch_format=False) \n",
412
- "print('Displaying generated images')\n",
413
- "display(pil_image)\n",
414
- "\n",
415
- "if download_file:\n",
416
- " files.download(name)\n",
417
- "\n"
418
- ]
419
- },
420
- {
421
- "cell_type": "code",
422
- "execution_count": null,
423
- "id": "da5ee254",
424
- "metadata": {},
425
- "outputs": [],
426
- "source": [
427
- "#@title Generate images with IC-GAN + CLIP!\n",
428
- "#@markdown 1. For **prompt** OpenAI suggest to use the template \"A photo of a X.\" or \"A photo of a X, a type of Y.\" [[paper]](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language_Supervision.pdf)\n",
429
- "#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
430
- "#@markdown 1. Select which instance to condition on, following one of the following options:\n",
431
- "#@markdown 1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
432
- "#@markdown 1. **input_feature_index** write an integer from 0 to 1000. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
433
- "#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\") write an integer from 0 to 1000. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for a correspondence between class name and indexes.\n",
434
- "#@markdown 1. Vary **truncation** from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Trucation values between 0.7 and 0.9 seem to empirically work well.\n",
435
- "#@markdown 4. **seed**=0 means no seed.\n",
436
- "prompt = 'A dragon' #@param {type:'string'}\n",
437
- "gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
438
- "if gen_model == 'icgan': \n",
439
- " experiment_name = 'icgan_biggan_imagenet_res256_nofeataug'\n",
440
- "else:\n",
441
- " experiment_name = 'cc_icgan_biggan_imagenet_res256_nofeataug'\n",
442
- "#last_gen_model = experiment_name\n",
443
- "size = '256'\n",
444
- "input_image_instance = \"\"#@param {type:\"string\"}\n",
445
- "\n",
446
- "input_feature_index = 500#@param {type:'integer'}\n",
447
- "class_index = 627 #@param {type:'integer'} (only with cc_icgan)\n",
448
- "download_image = False #@param {type:'boolean'}\n",
449
- "download_video = False #@param {type:'boolean'}\n",
450
- "truncation = 0.7 #@param {type:'number'}\n",
451
- "stochastic_truncation = False #@param {type:'boolean'}\n",
452
- "optimizer = 'CMA-ES' #@param ['SGD','Adam','CMA-ES','CMA-ES + SGD interleaved','CMA-ES + Adam interleaved','CMA-ES + terminal SGD','CMA-ES + terminal Adam']\n",
453
- "pop_size = 50 #@param {type:'integer'}\n",
454
- "clip_model = 'ViT-B/32' #@param ['ViT-B/32','RN50','RN101','RN50x4']\n",
455
- "augmentations = 64#@param {type:'integer'}\n",
456
- "learning_rate = 0.1#@param {type:'number'}\n",
457
- "noise_normality_loss = 0#@param {type:'number'}\n",
458
- "minimum_entropy_loss = 0.0001 #@param {type:'number'}\n",
459
- "total_variation_loss = 0.1 #@param {type:'number'}\n",
460
- "iterations = 100#@param {type:'integer'}\n",
461
- "terminal_iterations = 100#@param {type:'integer'}\n",
462
- "show_every = 1 #@param {type:'integer'}\n",
463
- "save_every = 1 #@param {type:'integer'}\n",
464
- "fps = 2#@param {type:'number'}\n",
465
- "freeze_secs = 0 #@param {type:'number'}\n",
466
- "seed = 10#@param {type:'number'}\n",
467
- "if seed == 0:\n",
468
- " seed = None\n",
469
- "\n",
470
- "softmax_temp = 1\n",
471
- "emb_factor = 0.067 #calculated empirically \n",
472
- "loss_factor = 100\n",
473
- "sigma0 = 0.5 #http://cma.gforge.inria.fr/cmaes_sourcecode_page.html#practical\n",
474
- "cma_adapt = True\n",
475
- "cma_diag = False\n",
476
- "cma_active = True\n",
477
- "cma_elitist = False\n",
478
- "noise_size = 128\n",
479
- "class_size = 1000\n",
480
- "channels = 3\n",
481
- "if gen_model == 'icgan':\n",
482
- " class_index = None\n",
483
- "\n",
484
- "import numpy as np\n",
485
- "state = None if not seed else np.random.RandomState(seed)\n",
486
- "np.random.seed(seed)\n",
487
- "# Load features \n",
488
- "if input_image_instance not in ['None',\"\"]:\n",
489
- " print('Obtaining instance features from input image!')\n",
490
- " input_feature_index = None\n",
491
- " feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
492
- " input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
493
- " input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
494
- " input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
495
- "elif input_feature_index is not None:\n",
496
- " print('Selecting an instance from pre-extracted vectors!')\n",
497
- " feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
498
- " input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
499
- "else:\n",
500
- " input_features = None\n",
501
- "\n",
502
- "\n",
503
- "# Load generative model\n",
504
- "model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
505
- "\n",
506
- "# Load CLIP model\n",
507
- "if clip_model != last_clip_model:\n",
508
- " perceptor, preprocess = clip.load(clip_model)\n",
509
- " last_clip_model = clip_model\n",
510
- "clip_res = perceptor.visual.input_resolution\n",
511
- "sideX = sideY = int(size)\n",
512
- "if sideX<=clip_res and sideY<=clip_res:\n",
513
- " augmentations = 1\n",
514
- "if 'CMA' not in optimizer:\n",
515
- " pop_size = 1\n",
516
- "\n",
517
- "# Prepare other variables\n",
518
- "name_file = '%s_%s_class_index%s_instance_index%s'%(gen_model, prompt, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
519
- "requires_grad = ('SGD' in optimizer or 'Adam' in optimizer) and ('terminal' not in optimizer or terminal_iterations>0)\n",
520
- "total_iterations = iterations + terminal_iterations*('terminal' in optimizer)\n",
521
- "\n",
522
- "!rm -rf /content/output\n",
523
- "!mkdir -p /content/output\n",
524
- "\n",
525
- "replace_to_inplace_relu(model)\n",
526
- "replace_to_inplace_relu(perceptor)\n",
527
- "ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
528
- "eps = 1e-8\n",
529
- "\n",
530
- "# Create noise and instance vector\n",
531
- "noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(pop_size, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
532
- "noise_vector = torch.tensor(noise_vector, requires_grad=requires_grad, device='cuda')\n",
533
- "if input_features is not None:\n",
534
- " instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda')\n",
535
- "else: \n",
536
- " instance_vector = None\n",
537
- "if class_index is not None:\n",
538
- " print('Conditioning on class: ', ind2name[class_index])\n",
539
- "if input_feature_index is not None:\n",
540
- " print('Conditioning on instance with index: ', input_feature_index)\n",
541
- "\n",
542
- "# Prepare optimizer\n",
543
- "if requires_grad:\n",
544
- " params = [noise_vector]\n",
545
- " if 'SGD' in optimizer:\n",
546
- " optim = torch.optim.SGD(params, lr=learning_rate, momentum=0.9) \n",
547
- " else:\n",
548
- " optim = torch.optim.Adam(params, lr=learning_rate)\n",
549
- "\n",
550
- "def ascend_txt(i, grad_step=False, show_save=False):\n",
551
- " global global_best_loss, global_best_iteration, global_best_noise_vector, global_best_class_vector\n",
552
- " regs = []\n",
553
- " losses = []\n",
554
- " total_losses = []\n",
555
- " best_loss = np.inf\n",
556
- " global_reg = torch.tensor(0, device='cuda', dtype=torch.float32, requires_grad=grad_step)\n",
557
- " if noise_normality_loss:\n",
558
- " global_reg = global_reg+noise_normality_loss*normality_loss(noise_vector)\n",
559
- " global_reg = loss_factor*global_reg \n",
560
- " if grad_step:\n",
561
- " global_reg.backward()\n",
562
- " for j in range(pop_size):\n",
563
- " p_s = []\n",
564
- " out = get_output(noise_vector[j:j+1], [class_index] if class_index is not None else None, instance_vector)\n",
565
- " for aug in range(augmentations):\n",
566
- " if sideX<=clip_res and sideY<=clip_res or augmentations==1:\n",
567
- " apper = out \n",
568
- " else:\n",
569
- " size = torch.randint(int(.7*sideX), int(.98*sideX), ())\n",
570
- " offsetx = torch.randint(0, sideX - size, ())\n",
571
- " offsety = torch.randint(0, sideX - size, ())\n",
572
- " apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]\n",
573
- " apper = (apper+1)/2\n",
574
- " apper = nn.functional.interpolate(apper, clip_res, mode='bilinear')\n",
575
- " #apper = apper.clamp(0,1)\n",
576
- " p_s.append(apper)\n",
577
- " into = nom(torch.cat(p_s, 0))\n",
578
- " predict_clip = perceptor.encode_image(into)\n",
579
- " loss = loss_factor*(1-torch.cosine_similarity(predict_clip, target_clip).mean())\n",
580
- " total_loss = loss\n",
581
- " regs.append(global_reg.item())\n",
582
- "\n",
583
- " with torch.no_grad():\n",
584
- " losses.append(loss.item())\n",
585
- " total_losses.append(total_loss.item()+global_reg.item())\n",
586
- " if total_losses[-1]<best_loss:\n",
587
- " best_loss = total_losses[-1]\n",
588
- " best_ind = j\n",
589
- " best_out = out\n",
590
- " if best_loss < global_best_loss:\n",
591
- " global_best_loss = best_loss\n",
592
- " global_best_iteration = i\n",
593
- " with torch.no_grad():\n",
594
- " global_best_noise_vector = noise_vector[best_ind]\n",
595
- " if grad_step: \n",
596
- " total_loss.backward()\n",
597
- "\n",
598
- " if grad_step:\n",
599
- " optim.step()\n",
600
- " optim.zero_grad()\n",
601
- "\n",
602
- " if show_save and (save_every and i % save_every == 0 or show_every and i % show_every == 0):\n",
603
- " noise = None\n",
604
- " emb = None\n",
605
- " with torch.no_grad():\n",
606
- " noise = noise_vector.cpu().numpy()\n",
607
- " checkin(i, best_ind, total_losses, losses, regs, best_out, noise, emb) \n",
608
- " return total_losses, best_ind\n",
609
- "\n",
610
- "# Obtain target CLIP representation\n",
611
- "tx = clip.tokenize(prompt)\n",
612
- "with torch.no_grad():\n",
613
- " target_clip = perceptor.encode_text(tx.cuda())\n",
614
- "\n",
615
- "\n",
616
- "global_best_loss = np.inf\n",
617
- "global_best_iteration = 0\n",
618
- "global_best_noise_vector = None\n",
619
- "\n",
620
- "nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
621
- "if 'CMA' in optimizer:\n",
622
- " initial_vector = np.zeros(noise_size)\n",
623
- " bounds = None\n",
624
- " cma_opts = {'popsize': pop_size, 'seed': np.nan, 'AdaptSigma': cma_adapt, 'CMA_diagonal': cma_diag, 'CMA_active': cma_active, 'CMA_elitist':cma_elitist, 'bounds':bounds}\n",
625
- " cmaes = cma.CMAEvolutionStrategy(initial_vector, sigma0, inopts=cma_opts)\n",
626
- "\n",
627
- "sample_num = 0\n",
628
- "machine = !nvidia-smi -L\n",
629
- "start = time()\n",
630
- "\n",
631
- "# Start noise vector optimization\n",
632
- "for i in range(total_iterations): \n",
633
- " if 'CMA' in optimizer and i<iterations:\n",
634
- " with torch.no_grad():\n",
635
- " cma_results = torch.tensor(cmaes.ask(), dtype=torch.float32).cuda()\n",
636
- " noise_vector.data = cma_results \n",
637
- " if requires_grad and ('terminal' not in optimizer or i>=iterations):\n",
638
- " losses, best_ind = ascend_txt(i, grad_step=True, show_save='CMA' not in optimizer or i>=iterations)\n",
639
- " assert noise_vector.requires_grad and noise_vector.is_leaf and (not optimize_class or class_vector.requires_grad and class_vector.is_leaf), (noise_vector.requires_grad, noise_vector.is_leaf, class_vector.requires_grad, class_vector.is_leaf)\n",
640
- " if 'CMA' in optimizer and i<iterations:\n",
641
- " with torch.no_grad():\n",
642
- " losses, best_ind = ascend_txt(i, show_save=True)\n",
643
- " if i<iterations-1:\n",
644
- " vectors = noise_vector\n",
645
- " cmaes.tell(vectors.cpu().numpy(), losses)\n",
646
- " elif 'terminal' in optimizer and terminal_iterations:\n",
647
- " pop_size = 1\n",
648
- " noise_vector[0] = global_best_noise_vector\n",
649
- " if save_every and i % save_every == 0 or show_every and i % show_every == 0:\n",
650
- " print('took: %d secs (%.2f sec/iter) on %s. CUDA memory: %.1f GB'%(time()-start,(time()-start)/(i+1), machine[0], torch.cuda.max_memory_allocated()/1024**3))\n",
651
- "\n",
652
- "# Obtain generated image with lowest loss.\n",
653
- "out = get_output(global_best_noise_vector.unsqueeze(0), [class_index] if class_index is not None else None, instance_vector)\n",
654
- "name = '/content/%s_best_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
655
- "pil_image = save(out,name) \n",
656
- "display(pil_image) \n",
657
- "print('best_loss=%.2f best_iter=%d'%(global_best_loss,global_best_iteration))\n",
658
- "\n",
659
- "if download_image:\n",
660
- " from google.colab import files, output\n",
661
- " files.download(name)\n",
662
- "\n",
663
- "if download_video:\n",
664
- " out = '\"/content/%s_seed%i.mp4\"'%(name_file, seed if seed is not None else -1)\n",
665
- " file_name = '/content/%s_seed%i.mp4'%(name_file, seed if seed is not None else -1)\n",
666
- "\n",
667
- " with open('/content/list.txt','w') as f:\n",
668
- " for i in range(sample_num):\n",
669
- " f.write('file /content/output/frame_%05d.jpg\\n'%i)\n",
670
- " for j in range(int(freeze_secs*fps)):\n",
671
- " f.write('file /content/output/frame_%05d.jpg\\n'%i)\n",
672
- " !ffmpeg -r $fps -f concat -safe 0 -i /content/list.txt -c:v libx264 -pix_fmt yuv420p -profile:v baseline -movflags +faststart -r $fps $out -y\n",
673
- " with open(file_name, 'rb') as f:\n",
674
- " data_url = \"data:video/mp4;base64,\" + b64encode(f.read()).decode()\n",
675
- " display(HTML(\"\"\"\n",
676
- " <video controls autoplay loop>\n",
677
- " <source src=\"%s\" type=\"video/mp4\">\n",
678
- " </video>\"\"\" % data_url))\n",
679
- "\n",
680
- " from google.colab import files, output\n",
681
- " output.eval_js('new Audio(\"https://freesound.org/data/previews/80/80921_1022651-lq.ogg\").play()')\n",
682
- " files.download(file_name)"
683
- ]
684
- }
685
- ],
686
- "metadata": {
687
- "kernelspec": {
688
- "display_name": "Python 3",
689
- "language": "python",
690
- "name": "python3"
691
- },
692
- "language_info": {
693
- "codemirror_mode": {
694
- "name": "ipython",
695
- "version": 3
696
- },
697
- "file_extension": ".py",
698
- "mimetype": "text/x-python",
699
- "name": "python",
700
- "nbconvert_exporter": "python",
701
- "pygments_lexer": "ipython3",
702
- "version": "3.8.10"
703
- }
704
- },
705
- "nbformat": 4,
706
- "nbformat_minor": 5
707
- }