facehuggingjay commited on
Commit
5269ed5
·
verified ·
1 Parent(s): e972aca

eehhhhhhh!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

Browse files
Files changed (1) hide show
  1. app.py +1368 -574
app.py CHANGED
@@ -1,580 +1,1374 @@
1
- import spaces
2
- import contextlib
3
- import gc
4
- import json
5
- import logging
6
- import math
7
- import os
8
- import random
9
- import shutil
10
- import sys
11
- import time
12
- import itertools
13
- from pathlib import Path
14
-
15
- import cv2
16
- import numpy as np
17
- from PIL import Image, ImageDraw
18
  import torch
 
 
 
 
 
19
  import torch.nn.functional as F
20
- import torch.utils.checkpoint
21
- from torch.utils.data import Dataset
22
- from torchvision import transforms
23
- from tqdm.auto import tqdm
24
-
25
- import accelerate
26
- from accelerate import Accelerator
27
- from accelerate.logging import get_logger
28
- from accelerate.utils import ProjectConfiguration, set_seed
29
-
30
- from datasets import load_dataset
31
- from huggingface_hub import create_repo, upload_folder
32
- from packaging import version
33
- from safetensors.torch import load_model
34
- from peft import LoraConfig
35
- import gradio as gr
36
- import pandas as pd
37
-
38
- import transformers
39
- from transformers import (
40
- AutoTokenizer,
41
- PretrainedConfig,
42
- CLIPVisionModelWithProjection,
43
- CLIPImageProcessor,
44
- CLIPProcessor,
45
- )
46
-
47
- import diffusers
48
- from diffusers import (
49
- AutoencoderKL,
50
- DDPMScheduler,
51
- ColorGuiderPixArtModel,
52
- ColorGuiderSDModel,
53
- UNet2DConditionModel,
54
- PixArtTransformer2DModel,
55
- ColorFlowPixArtAlphaPipeline,
56
- ColorFlowSDPipeline,
57
- UniPCMultistepScheduler,
58
- )
59
- from colorflow_utils.utils import *
60
-
61
- sys.path.append('./BidirectionalTranslation')
62
- from options.test_options import TestOptions
63
- from models import create_model
64
- from util import util
65
-
66
- from huggingface_hub import snapshot_download
67
-
68
-
69
- article = r"""
70
- If ColorFlow is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/ColorFlow' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/ColorFlow)](https://github.com/TencentARC/ColorFlow)
71
- ---
72
-
73
- 📧 **Contact**
74
- <br>
75
- If you have any questions, please feel free to reach me out at <b>zhuangjh23@mails.tsinghua.edu.cn</b>.
76
-
77
- 📝 **Citation**
78
- <br>
79
- If our work is useful for your research, please consider citing:
80
- ```bibtex
81
- @misc{zhuang2024colorflow,
82
- title={ColorFlow: Retrieval-Augmented Image Sequence Colorization},
83
- author={Junhao Zhuang and Xuan Ju and Zhaoyang Zhang and Yong Liu and Shiyi Zhang and Chun Yuan and Ying Shan},
84
- year={2024},
85
- eprint={2412.11815},
86
- archivePrefix={arXiv},
87
- primaryClass={cs.CV},
88
- url={https://arxiv.org/abs/2412.11815},
89
- }
90
- ```
91
- """
92
-
93
- model_global_path = snapshot_download(repo_id="TencentARC/ColorFlow", cache_dir='./colorflow/', repo_type="model")
94
- print(model_global_path)
95
-
96
-
97
- transform = transforms.Compose([
98
- transforms.ToTensor(),
99
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
100
- ])
101
- weight_dtype = torch.float16
102
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
-
104
- # line model
105
- line_model_path = model_global_path + '/LE/erika.pth'
106
- line_model = res_skip()
107
- line_model.load_state_dict(torch.load(line_model_path))
108
- line_model.eval()
109
- line_model.to(device)
110
-
111
- # screen model
112
- global opt
113
-
114
- opt = TestOptions().parse(model_global_path)
115
- ScreenModel = create_model(opt, model_global_path)
116
- ScreenModel.setup(opt)
117
- ScreenModel.eval()
118
-
119
- image_processor = CLIPImageProcessor()
120
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to(device)
121
-
122
-
123
- examples = [
124
- [
125
- "./assets/example_5/input.png",
126
- ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
127
- "GrayImage(ScreenStyle)",
128
- "800x512",
129
- 0,
130
- 10
131
- ],
132
- [
133
- "./assets/example_4/input.jpg",
134
- ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
135
- "GrayImage(ScreenStyle)",
136
- "640x640",
137
- 0,
138
- 10
139
- ],
140
- [
141
- "./assets/example_3/input.png",
142
- ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
143
- "GrayImage(ScreenStyle)",
144
- "800x512",
145
- 0,
146
- 10
147
- ],
148
- [
149
- "./assets/example_2/input.png",
150
- ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
151
- "GrayImage(ScreenStyle)",
152
- "800x512",
153
- 0,
154
- 10
155
- ],
156
- [
157
- "./assets/example_6/input.png",
158
- ["./assets/example_6/ref1.png", "./assets/example_6/ref2.png", "./assets/example_6/ref3.png"],
159
- "Sketch_Shading",
160
- "512x800",
161
- 0,
162
- 10
163
- ],
164
- [
165
- "./assets/example_7/input.jpg",
166
- ["./assets/example_7/ref1.jpg", "./assets/example_7/ref2.jpg", "./assets/example_7/ref3.jpg", "./assets/example_7/ref4.jpg"],
167
- "Sketch_Shading",
168
- "640x640",
169
- 2,
170
- 10
171
- ],
172
- [
173
- "./assets/example_1/input.jpg",
174
- ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
175
- "Sketch",
176
- "640x640",
177
- 1,
178
- 10
179
- ],
180
- [
181
- "./assets/example_0/input.jpg",
182
- ["./assets/example_0/ref1.jpg"],
183
- "Sketch",
184
- "640x640",
185
- 1,
186
- 10
187
- ],
188
- ]
189
-
190
- global pipeline
191
- global MultiResNetModel
192
-
193
- @spaces.GPU
194
- def load_ckpt(input_style):
195
- global pipeline
196
- global MultiResNetModel
197
- if input_style == "Sketch" or input_style == "Sketch_Shading":
198
- if input_style == "Sketch":
199
- ckpt_path = model_global_path + '/sketch/'
200
- rank = 128
201
- else:
202
- ckpt_path = model_global_path + '/shading/'
203
- rank = 128
204
- pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
205
- transformer = PixArtTransformer2DModel.from_pretrained(
206
- pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
207
- )
208
- pixart_config = get_pixart_config()
209
-
210
- ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)
211
-
212
- transformer_lora_config = LoraConfig(
213
- r=rank,
214
- lora_alpha=rank,
215
- init_lora_weights="gaussian",
216
- target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"]
217
- )
218
- transformer.add_adapter(transformer_lora_config)
219
- ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
220
- transformer.load_state_dict(ckpt_key_t, strict=False)
221
-
222
- transformer.to(device, dtype=weight_dtype)
223
- ColorGuider.to(device, dtype=weight_dtype)
224
-
225
- pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
226
- pretrained_model_name_or_path,
227
- transformer=transformer,
228
- colorguider=ColorGuider,
229
- safety_checker=None,
230
- revision=None,
231
- variant=None,
232
- torch_dtype=weight_dtype,
233
- )
234
- pipeline = pipeline.to(device)
235
- block_out_channels = [128, 128, 256, 512, 512]
236
-
237
- MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
238
- MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
239
- MultiResNetModel.to(device, dtype=weight_dtype)
240
-
241
- elif input_style == "GrayImage(ScreenStyle)":
242
- ckpt_path = model_global_path + '/GraySD/'
243
- rank = 64
244
- pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
245
- unet = UNet2DConditionModel.from_pretrained(
246
- pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
247
- )
248
- ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
249
- ColorGuider.to(device, dtype=weight_dtype)
250
- unet.to(device, dtype=weight_dtype)
251
-
252
- pipeline = ColorFlowSDPipeline.from_pretrained(
253
- pretrained_model_name_or_path,
254
- unet=unet,
255
- colorguider=ColorGuider,
256
- safety_checker=None,
257
- revision=None,
258
- variant=None,
259
- torch_dtype=weight_dtype,
260
- )
261
- pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
262
- unet_lora_config = LoraConfig(
263
- r=rank,
264
- lora_alpha=rank,
265
- init_lora_weights="gaussian",
266
- target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],#ff.net.0.proj ff.net.2
267
- )
268
- pipeline.unet.add_adapter(unet_lora_config)
269
- pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
270
- pipeline = pipeline.to(device)
271
- block_out_channels = [128, 128, 256, 512, 512]
272
-
273
- MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
274
- MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
275
- MultiResNetModel.to(device, dtype=weight_dtype)
276
-
277
-
278
-
279
-
280
-
281
- global cur_input_style
282
- cur_input_style = "Sketch"
283
- load_ckpt(cur_input_style)
284
- cur_input_style = "Sketch_Shading"
285
- load_ckpt(cur_input_style)
286
- cur_input_style = "GrayImage(ScreenStyle)"
287
- load_ckpt(cur_input_style)
288
- cur_input_style = None
289
-
290
- @spaces.GPU
291
- def fix_random_seeds(seed):
292
- random.seed(seed)
293
- np.random.seed(seed)
294
- torch.manual_seed(seed)
295
- if torch.cuda.is_available():
296
- torch.cuda.manual_seed(seed)
297
- torch.cuda.manual_seed_all(seed)
298
-
299
- def process_multi_images(files):
300
- images = [Image.open(file.name) for file in files]
301
- imgs = []
302
- for i, img in enumerate(images):
303
- imgs.append(img)
304
- return imgs
305
-
306
- @spaces.GPU
307
- def extract_lines(image):
308
- src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
309
-
310
- rows = int(np.ceil(src.shape[0] / 16)) * 16
311
- cols = int(np.ceil(src.shape[1] / 16)) * 16
312
-
313
- patch = np.ones((1, 1, rows, cols), dtype="float32")
314
- patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
315
-
316
- tensor = torch.from_numpy(patch).to(device)
317
-
318
- with torch.no_grad():
319
- y = line_model(tensor)
320
-
321
- yc = y.cpu().numpy()[0, 0, :, :]
322
- yc[yc > 255] = 255
323
- yc[yc < 0] = 0
324
-
325
- outimg = yc[0:src.shape[0], 0:src.shape[1]]
326
- outimg = outimg.astype(np.uint8)
327
- outimg = Image.fromarray(outimg)
328
- torch.cuda.empty_cache()
329
- return outimg
330
-
331
- @spaces.GPU
332
- def to_screen_image(input_image):
333
- global opt
334
- global ScreenModel
335
- input_image = input_image.convert('RGB')
336
- input_image = get_ScreenVAE_input(input_image, opt)
337
- h = input_image['h']
338
- w = input_image['w']
339
- ScreenModel.set_input(input_image)
340
- fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
341
- images=fake_B2[:,:,:h,:w]
342
- im = util.tensor2im(images)
343
- image_pil = Image.fromarray(im)
344
- torch.cuda.empty_cache()
345
- return image_pil
346
-
347
- @spaces.GPU
348
- def extract_line_image(query_image_, input_style, resolution):
349
- if resolution == "640x640":
350
- tar_width = 640
351
- tar_height = 640
352
- elif resolution == "512x800":
353
- tar_width = 512
354
- tar_height = 800
355
- elif resolution == "800x512":
356
- tar_width = 800
357
- tar_height = 512
358
  else:
359
- gr.Info("Unsupported resolution")
360
-
361
- query_image = process_image(query_image_, int(tar_width*1.5), int(tar_height*1.5))
362
- if input_style == "GrayImage(ScreenStyle)":
363
- extracted_line = to_screen_image(query_image)
364
- extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
365
- input_context = extracted_line
366
- elif input_style == "Sketch":
367
- query_image = query_image.convert('L').convert('RGB')
368
- extracted_line = extract_lines(query_image)
369
- extracted_line = extracted_line.convert('L').convert('RGB')
370
- input_context = extracted_line
371
- elif input_style == "Sketch_Shading":
372
- query_image = query_image.convert('L').convert('RGB')
373
- extracted_line = extract_lines(query_image)
374
- extracted_line = extracted_line.convert('L').convert('RGB')
375
- array1 = np.array(query_image)
376
- array2 = np.array(extracted_line)
377
- array2[array1 < 0.3 * 255.0] = 0
378
- gray_rate = 125
379
- up_bound = 145
380
- array2[(array2 > gray_rate) & (array1 < up_bound) & (array1 > 0.3 * 255.0)] = gray_rate
381
- input_context = Image.fromarray(np.uint8(array2))
382
- torch.cuda.empty_cache()
383
- return input_context, extracted_line, input_context
384
-
385
- @spaces.GPU(duration=180)
386
- def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
387
- if VAE_input is None or input_context is None:
388
- gr.Info("Please preprocess the image first")
389
- raise ValueError("Please preprocess the image first")
390
- global cur_input_style
391
- global pipeline
392
- global MultiResNetModel
393
- if input_style != cur_input_style:
394
- gr.Info(f"Loading {input_style} model...")
395
- load_ckpt(input_style)
396
- cur_input_style = input_style
397
- gr.Info(f"{input_style} model loaded")
398
- reference_images = process_multi_images(reference_images)
399
- fix_random_seeds(seed)
400
- if resolution == "640x640":
401
- tar_width = 640
402
- tar_height = 640
403
- elif resolution == "512x800":
404
- tar_width = 512
405
- tar_height = 800
406
- elif resolution == "800x512":
407
- tar_width = 800
408
- tar_height = 512
409
  else:
410
- gr.Info("Unsupported resolution")
411
- validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width*2, tar_height*2))
412
- gr.Info("Image retrieval in progress...")
413
- query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
414
- query_image = query_image_bw.convert('RGB')
415
- query_image_vae = process_image(VAE_input, int(tar_width*1.5), int(tar_height*1.5))
416
- reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
417
- query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
418
- reference_patches_pil = []
419
- for reference_image in reference_images:
420
- reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
421
- combined_image = None
422
- with torch.no_grad():
423
- clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
424
- query_embeddings = image_encoder(clip_img).image_embeds
425
- reference_patches_pil_gray = [rimg.convert('RGB').convert('RGB') for rimg in reference_patches_pil]
426
- clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
427
- reference_embeddings = image_encoder(clip_img).image_embeds
428
- cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
429
- sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
430
- top_k = 3
431
- top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
432
- combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
433
- combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width//2, tar_height//2))
434
- idx_table = {0:[(1,0), (0,1), (0,0)], 1:[(1,3), (0,2),(0,3)], 2:[(2,0),(3,1), (3,0)], 3:[(2,3), (3,2),(3,3)]}
435
- for i in range(2):
436
- for j in range(2):
437
- idx_list = idx_table[i * 2 + j]
438
- for k in range(top_k):
439
- ref_index = top_k_indices[i * 2 + j][k]
440
- idx_y = idx_list[k][0]
441
- idx_x = idx_list[k][1]
442
- combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
443
- gr.Info("Model inference in progress...")
444
- generator = torch.Generator(device=device).manual_seed(seed)
445
- image = pipeline(
446
- "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
447
- ).images[0]
448
- gr.Info("Post-processing image...")
449
- with torch.no_grad():
450
- width, height = image.size
451
- new_width = width // 2
452
- new_height = height // 2
453
- left = (width - new_width) // 2
454
- top = (height - new_height) // 2
455
- right = left + new_width
456
- bottom = top + new_height
457
- center_crop = image.crop((left, top, right, bottom))
458
- up_img = center_crop.resize(query_image_vae.size)
459
- test_low_color = transform(up_img).unsqueeze(0).to(device, dtype=weight_dtype)
460
- query_image_vae = transform(query_image_vae).unsqueeze(0).to(device, dtype=weight_dtype)
461
-
462
- h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
463
- h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
464
-
465
- hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim = 1) for hidden_idx in range(len(hidden_list_color))]
466
-
467
-
468
- hidden_list = MultiResNetModel(hidden_list_double)
469
- output = pipeline.vae._decode(h_color.sample(),return_dict = False, hidden_list = hidden_list)[0]
470
-
471
- output[output > 1] = 1
472
- output[output < -1] = -1
473
- high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
474
- gr.Info("Colorization complete!")
475
- torch.cuda.empty_cache()
476
- return high_res_image, up_img, image, query_image_bw
477
-
478
- with gr.Blocks() as demo:
479
- gr.HTML(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  """
481
- <div style="text-align: center;">
482
- <h1 style="text-align: center; font-size: 3em;">🎨 ColorFlow:</h1>
483
- <h3 style="text-align: center; font-size: 1.8em;">Retrieval-Augmented Image Sequence Colorization</h3>
484
- <p style="text-align: center; font-weight: bold;">
485
- <a href="https://zhuang2002.github.io/ColorFlow/">Project Page</a> |
486
- <a href="https://arxiv.org/abs/2412.11815">ArXiv Preprint</a> |
487
- <a href="https://github.com/TencentARC/ColorFlow">GitHub Repository</a>
488
- </p>
489
- <p style="text-align: center; font-weight: bold;">
490
- NOTE: Each time you switch the input style, the corresponding model will be reloaded, which may take some time. Please be patient.
491
- </p>
492
- <p style="text-align: left; font-size: 1.1em;">
493
- Welcome to the demo of <strong>ColorFlow</strong>. Follow the steps below to explore the capabilities of our model:
494
- </p>
495
- </div>
496
- <div style="text-align: left; margin: 0 auto;">
497
- <ol style="font-size: 1.1em;">
498
- <li>Choose input style: GrayImage(ScreenStyle)、Sketch with Shading or Sketch.</li>
499
- <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
500
- <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
501
- <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
502
- <li>Set sampling parameters (optional): Adjust the settings and click the <b>Colorize</b> button.</li>
503
- </ol>
504
- <p>
505
- ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU has an inference time limit of 180 seconds. You may need to log in with a free account to use this demo. Large sampling steps might lead to timeout (GPU Abort). In that case, please consider logging in with a Pro account or running it on your local machine.
506
- </p>
507
- </div>
508
- <div style="text-align: center;">
509
- <p style="text-align: center; font-weight: bold;">
510
- 注意:每次切换输入样式时,相应的模型将被重新加载,可能需要一些时间。请耐心等待。
511
- </p>
512
- <p style="text-align: left; font-size: 1.1em;">
513
- 欢迎使用 <strong>ColorFlow</strong> 演示。请按照以下步骤探索我们模型的能力:
514
- </p>
515
- </div>
516
- <div style="text-align: left; margin: 0 auto;">
517
- <ol style="font-size: 1.1em;">
518
- <li>选择输入样式:灰度图(ScreenStyle)、线稿+阴影、线稿。</li>
519
- <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
520
- <li>预处理图像:点击“预处理”按钮以去色图像。</li>
521
- <li>上传参考图像:上传多张参考图像以指导上色。</li>
522
- <li>设置采样参数(可选):调整设置并点击 <b>上色</b> 按钮。</li>
523
- </ol>
524
- <p>
525
- ⏱️ <b>ZeroGPU时间限制</b>:Hugging Face ZeroGPU 的推理时间限制为 180 秒。您可能需要使用免费帐户登录以使用此演示。大采样步骤可能会导致超时(GPU 中止)。在这种情况下,请考虑使用专业帐户登录或在本地计算机上运行。
526
- </p>
527
- </div>
528
  """
529
- )
530
- VAE_input = gr.State()
531
- input_context = gr.State()
532
- # example_loading = gr.State(value=None)
533
-
534
- with gr.Column():
535
- with gr.Row():
536
- input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch_Shading", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
537
- with gr.Row():
538
- with gr.Column():
539
- input_image = gr.Image(type="pil", label="Image to Colorize")
540
- resolution = gr.Radio(["640x640", "512x800", "800x512"], label="Select Resolution(Width*Height)", value="640x640")
541
- extract_button = gr.Button("Preprocess (Decolorize)")
542
- extracted_image = gr.Image(type="pil", label="Decolorized Result")
543
- with gr.Row():
544
- reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
545
- with gr.Column():
546
- output_gallery = gr.Gallery(label="Colorization Results", type="pil")
547
- seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
548
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=100, value=10, step=1)
549
- colorize_button = gr.Button("Colorize")
550
-
551
- # progress_text = gr.Textbox(label="Progress", interactive=False)
552
-
553
-
554
- extract_button.click(
555
- extract_line_image,
556
- inputs=[input_image, input_style, resolution],
557
- outputs=[extracted_image, VAE_input, input_context]
558
- )
559
- colorize_button.click(
560
- colorize_image,
561
- inputs=[VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps],
562
- outputs=output_gallery
563
- )
564
-
565
- with gr.Column():
566
- gr.Markdown("### Quick Examples")
567
- gr.Examples(
568
- examples=examples,
569
- inputs=[input_image, reference_images, input_style, resolution, seed, num_inference_steps],
570
- label="Examples",
571
- examples_per_page=8,
572
- )
573
- gr.HTML('<a href="https://github.com/TencentARC/ColorFlow"><img src="https://img.shields.io/github/stars/TencentARC/ColorFlow" alt="GitHub Stars"></a>')
574
- gr.Markdown(article)
575
- # gr.HTML(
576
- # '<a href="https://github.com/TencentARC/ColorFlow"><img src="https://img.shields.io/github/stars/TencentARC/ColorFlow" alt="GitHub Stars"></a>'
577
- # )
578
-
579
-
580
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ import numpy as np
7
  import torch.nn.functional as F
8
+ from torch.nn.modules.normalization import LayerNorm
9
+ import os
10
+ from torch.nn.utils import spectral_norm
11
+ from torchvision import models
12
+
13
+ ###############################################################################
14
+ # Helper functions
15
+ ###############################################################################
16
+
17
+
18
+ def init_weights(net, init_type='normal', init_gain=0.02):
19
+ """Initialize network weights.
20
+ Parameters:
21
+ net (network) -- network to be initialized
22
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
23
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
24
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
25
+ work better for some applications. Feel free to try yourself.
26
+ """
27
+ def init_func(m): # define the initialization function
28
+ classname = m.__class__.__name__
29
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30
+ if init_type == 'normal':
31
+ init.normal_(m.weight.data, 0.0, init_gain)
32
+ elif init_type == 'xavier':
33
+ init.xavier_normal_(m.weight.data, gain=init_gain)
34
+ elif init_type == 'kaiming':
35
+ #init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
36
+ init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
37
+ elif init_type == 'orthogonal':
38
+ init.orthogonal_(m.weight.data, gain=init_gain)
39
+ else:
40
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
41
+ if hasattr(m, 'bias') and m.bias is not None:
42
+ init.constant_(m.bias.data, 0.0)
43
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
44
+ init.normal_(m.weight.data, 1.0, init_gain)
45
+ init.constant_(m.bias.data, 0.0)
46
+
47
+ print('initialize network with %s' % init_type)
48
+ net.apply(init_func) # apply the initialization function <init_func>
49
+
50
+
51
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
52
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
53
+ Parameters:
54
+ net (network) -- the network to be initialized
55
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
56
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
57
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
58
+ Return an initialized network.
59
+ """
60
+ if len(gpu_ids) > 0 and torch.cuda.is_available():
61
+ net.to(gpu_ids[0])
62
+ if init:
63
+ init_weights(net, init_type, init_gain=init_gain)
64
+ return net
65
+
66
+
67
+ def get_scheduler(optimizer, opt):
68
+ """Return a learning rate scheduler
69
+ Parameters:
70
+ optimizer -- the optimizer of the network
71
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
72
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
73
+ For 'linear', we keep the same learning rate for the first <opt.niter> epochs
74
+ and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
75
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
76
+ See https://pytorch.org/docs/stable/optim.html for more details.
77
+ """
78
+ if opt.lr_policy == 'linear':
79
+ def lambda_rule(epoch):
80
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
81
+ return lr_l
82
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
83
+ elif opt.lr_policy == 'step':
84
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
85
+ elif opt.lr_policy == 'plateau':
86
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
87
+ elif opt.lr_policy == 'cosine':
88
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
89
+ else:
90
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
91
+ return scheduler
92
+
93
+ class LayerNormWarpper(nn.Module):
94
+ def __init__(self, num_features):
95
+ super(LayerNormWarpper, self).__init__()
96
+ self.num_features = int(num_features)
97
+
98
+ def forward(self, x):
99
+ x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).to(x.device)(x)
100
+ return x
101
+
102
+ def get_norm_layer(norm_type='instance'):
103
+ """Return a normalization layer
104
+ Parameters:
105
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
106
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
107
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
108
+ """
109
+ if norm_type == 'batch':
110
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
111
+ elif norm_type == 'instance':
112
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
113
+ elif norm_type == 'layer':
114
+ norm_layer = functools.partial(LayerNormWarpper)
115
+ elif norm_type == 'none':
116
+ norm_layer = None
117
+ else:
118
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
119
+ return norm_layer
120
+
121
+
122
+ def get_non_linearity(layer_type='relu'):
123
+ if layer_type == 'relu':
124
+ nl_layer = functools.partial(nn.ReLU, inplace=True)
125
+ elif layer_type == 'lrelu':
126
+ nl_layer = functools.partial(
127
+ nn.LeakyReLU, negative_slope=0.2, inplace=True)
128
+ elif layer_type == 'elu':
129
+ nl_layer = functools.partial(nn.ELU, inplace=True)
130
+ elif layer_type == 'selu':
131
+ nl_layer = functools.partial(nn.SELU, inplace=True)
132
+ elif layer_type == 'prelu':
133
+ nl_layer = functools.partial(nn.PReLU)
134
+ else:
135
+ raise NotImplementedError(
136
+ 'nonlinearity activitation [%s] is not found' % layer_type)
137
+ return nl_layer
138
+
139
+
140
+ def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
141
+ use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
142
+ net = None
143
+ norm_layer = get_norm_layer(norm_type=norm)
144
+ nl_layer = get_non_linearity(layer_type=nl)
145
+ # print(norm, norm_layer)
146
+
147
+ if nz == 0:
148
+ where_add = 'input'
149
+
150
+ if netG == 'unet_128' and where_add == 'input':
151
+ net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
152
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
153
+ elif netG == 'unet_128_G' and where_add == 'input':
154
+ net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
155
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
156
+ elif netG == 'unet_256' and where_add == 'input':
157
+ net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
158
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
159
+ elif netG == 'unet_256_G' and where_add == 'input':
160
+ net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
161
+ use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
162
+ elif netG == 'unet_128' and where_add == 'all':
163
+ net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
164
+ use_dropout=use_dropout, upsample=upsample)
165
+ elif netG == 'unet_256' and where_add == 'all':
166
+ net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
167
+ use_dropout=use_dropout, upsample=upsample)
168
+ else:
169
+ raise NotImplementedError('Generator model name [%s] is not recognized' % net)
170
+ # print(net)
171
+ return init_net(net, init_type, init_gain, gpu_ids)
172
+
173
+
174
+ def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
175
+ use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
176
+ net = None
177
+ norm_layer = get_norm_layer(norm_type=norm)
178
+ nl_layer = get_non_linearity(layer_type=nl)
179
+
180
+ if netC == 'resnet_9blocks':
181
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
182
+ elif netC == 'resnet_6blocks':
183
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
184
+ elif netC == 'unet_128':
185
+ net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
186
+ use_dropout=use_dropout, upsample=upsample)
187
+ elif netC == 'unet_256':
188
+ net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
189
+ use_dropout=use_dropout, upsample=upsample)
190
+ elif netC == 'unet_32':
191
+ net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
192
+ use_dropout=use_dropout, upsample=upsample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  else:
194
+ raise NotImplementedError('Generator model name [%s] is not recognized' % net)
195
+
196
+ return init_net(net, init_type, init_gain, gpu_ids)
197
+
198
+
199
+ def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
200
+ net = None
201
+ norm_layer = get_norm_layer(norm_type=norm)
202
+ nl = 'lrelu' # use leaky relu for D
203
+ nl_layer = get_non_linearity(layer_type=nl)
204
+
205
+ if netD == 'basic_128':
206
+ net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
207
+ elif netD == 'basic_256':
208
+ net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
209
+ elif netD == 'basic_128_multi':
210
+ net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
211
+ elif netD == 'basic_256_multi':
212
+ net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
213
+ else:
214
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
215
+ return init_net(net, init_type, init_gain, gpu_ids)
216
+
217
+
218
+ def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
219
+ init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
220
+ net = None
221
+ norm_layer = get_norm_layer(norm_type=norm)
222
+ nl = 'lrelu' # use leaky relu for E
223
+ nl_layer = get_non_linearity(layer_type=nl)
224
+ if netE == 'resnet_128':
225
+ net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
226
+ nl_layer=nl_layer, vaeLike=vaeLike)
227
+ elif netE == 'resnet_256':
228
+ net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
229
+ nl_layer=nl_layer, vaeLike=vaeLike)
230
+ elif netE == 'conv_128':
231
+ net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
232
+ nl_layer=nl_layer, vaeLike=vaeLike)
233
+ elif netE == 'conv_256':
234
+ net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
235
+ nl_layer=nl_layer, vaeLike=vaeLike)
 
 
 
 
 
 
 
 
236
  else:
237
+ raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
238
+
239
+ return init_net(net, init_type, init_gain, gpu_ids, False)
240
+
241
+
242
+ class ResnetGenerator(nn.Module):
243
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
244
+ assert(n_blocks >= 0)
245
+ super(ResnetGenerator, self).__init__()
246
+ self.input_nc = input_nc
247
+ self.output_nc = output_nc
248
+ self.ngf = ngf
249
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
250
+ use_bias = norm_layer.func != nn.BatchNorm2d
251
+ else:
252
+ use_bias = norm_layer != nn.BatchNorm2d
253
+
254
+ model = [nn.ReplicationPad2d(3),
255
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
256
+ bias=use_bias)]
257
+ if norm_layer is not None:
258
+ model += [norm_layer(ngf)]
259
+ model += [nn.ReLU(True)]
260
+
261
+ # n_downsampling = 2
262
+ for i in range(n_downsampling):
263
+ mult = 2**i
264
+ model += [nn.ReplicationPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
265
+ stride=2, padding=0, bias=use_bias)]
266
+ # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
267
+ # stride=2, padding=1, bias=use_bias)]
268
+ if norm_layer is not None:
269
+ model += [norm_layer(ngf * mult * 2)]
270
+ model += [nn.ReLU(True)]
271
+
272
+ mult = 2**n_downsampling
273
+ for i in range(n_blocks):
274
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
275
+
276
+ for i in range(n_downsampling):
277
+ mult = 2**(n_downsampling - i)
278
+ # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
279
+ # kernel_size=3, stride=2,
280
+ # padding=1, output_padding=1,
281
+ # bias=use_bias)]
282
+ # if norm_layer is not None:
283
+ # model += [norm_layer(ngf * mult / 2)]
284
+ # model += [nn.ReLU(True)]
285
+ model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
286
+ if norm_layer is not None:
287
+ model += [norm_layer(int(ngf * mult / 2))]
288
+ model += [nn.ReLU(True)]
289
+ model +=[nn.ReplicationPad2d(1),
290
+ nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
291
+ if norm_layer is not None:
292
+ model += [norm_layer(ngf * mult / 2)]
293
+ model += [nn.ReLU(True)]
294
+ model += [nn.ReplicationPad2d(3)]
295
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
296
+ #model += [nn.Tanh()]
297
+
298
+ self.model = nn.Sequential(*model)
299
+
300
+ def forward(self, input):
301
+ return self.model(input)
302
+
303
+
304
+ # Define a resnet block
305
+ class ResnetBlock(nn.Module):
306
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
307
+ super(ResnetBlock, self).__init__()
308
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
309
+
310
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
311
+ conv_block = []
312
+ p = 0
313
+ if padding_type == 'reflect':
314
+ conv_block += [nn.ReflectionPad2d(1)]
315
+ elif padding_type == 'replicate':
316
+ conv_block += [nn.ReplicationPad2d(1)]
317
+ elif padding_type == 'zero':
318
+ p = 1
319
+ else:
320
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
321
+
322
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
323
+ if norm_layer is not None:
324
+ conv_block += [norm_layer(dim)]
325
+ conv_block += [nn.ReLU(True)]
326
+ # if use_dropout:
327
+ # conv_block += [nn.Dropout(0.5)]
328
+
329
+ p = 0
330
+ if padding_type == 'reflect':
331
+ conv_block += [nn.ReflectionPad2d(1)]
332
+ elif padding_type == 'replicate':
333
+ conv_block += [nn.ReplicationPad2d(1)]
334
+ elif padding_type == 'zero':
335
+ p = 1
336
+ else:
337
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
338
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
339
+ if norm_layer is not None:
340
+ conv_block += [norm_layer(dim)]
341
+
342
+ return nn.Sequential(*conv_block)
343
+
344
+ def forward(self, x):
345
+ out = x + self.conv_block(x)
346
+ return out
347
+
348
+
349
+ class D_NLayersMulti(nn.Module):
350
+ def __init__(self, input_nc, ndf=64, n_layers=3,
351
+ norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
352
+ super(D_NLayersMulti, self).__init__()
353
+ # st()
354
+ self.num_D = num_D
355
+ self.nl_layer=nl_layer
356
+ if num_D == 1:
357
+ layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
358
+ self.model = nn.Sequential(*layers)
359
+ else:
360
+ layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
361
+ self.add_module("model_0", nn.Sequential(*layers))
362
+ self.down = nn.functional.interpolate
363
+ for i in range(1, num_D):
364
+ ndf_i = int(round(ndf / (2**i)))
365
+ layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
366
+ self.add_module("model_%d" % i, nn.Sequential(*layers))
367
+
368
+ def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
369
+ kw = 3
370
+ padw = 1
371
+ sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
372
+ stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]
373
+
374
+ nf_mult = 1
375
+ nf_mult_prev = 1
376
+ for n in range(1, n_layers):
377
+ nf_mult_prev = nf_mult
378
+ nf_mult = min(2**n, 8)
379
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
380
+ kernel_size=kw, stride=2, padding=padw))]
381
+ if norm_layer:
382
+ sequence += [norm_layer(ndf * nf_mult)]
383
+
384
+ sequence += [self.nl_layer()]
385
+
386
+ nf_mult_prev = nf_mult
387
+ nf_mult = min(2**n_layers, 8)
388
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
389
+ kernel_size=kw, stride=1, padding=padw))]
390
+ if norm_layer:
391
+ sequence += [norm_layer(ndf * nf_mult)]
392
+ sequence += [self.nl_layer()]
393
+
394
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
395
+ kernel_size=kw, stride=1, padding=padw))]
396
+
397
+ return sequence
398
+
399
+ def forward(self, input):
400
+ if self.num_D == 1:
401
+ return self.model(input)
402
+ result = []
403
+ down = input
404
+ for i in range(self.num_D):
405
+ model = getattr(self, "model_%d" % i)
406
+ result.append(model(down))
407
+ if i != self.num_D - 1:
408
+ down = self.down(down, scale_factor=0.5, mode='bilinear')
409
+ return result
410
+
411
+ class D_NLayers(nn.Module):
412
+ """Defines a PatchGAN discriminator"""
413
+
414
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
415
+ """Construct a PatchGAN discriminator
416
+ Parameters:
417
+ input_nc (int) -- the number of channels in input images
418
+ ndf (int) -- the number of filters in the last conv layer
419
+ n_layers (int) -- the number of conv layers in the discriminator
420
+ norm_layer -- normalization layer
421
+ """
422
+ super(D_NLayers, self).__init__()
423
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
424
+ use_bias = norm_layer.func != nn.BatchNorm2d
425
+ else:
426
+ use_bias = norm_layer != nn.BatchNorm2d
427
+
428
+ kw = 3
429
+ padw = 1
430
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
431
+ nf_mult = 1
432
+ nf_mult_prev = 1
433
+ for n in range(1, n_layers): # gradually increase the number of filters
434
+ nf_mult_prev = nf_mult
435
+ nf_mult = min(2 ** n, 8)
436
+ sequence += [
437
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
438
+ norm_layer(ndf * nf_mult),
439
+ nn.LeakyReLU(0.2, True)
440
+ ]
441
+
442
+ nf_mult_prev = nf_mult
443
+ nf_mult = min(2 ** n_layers, 8)
444
+ sequence += [
445
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
446
+ norm_layer(ndf * nf_mult),
447
+ nn.LeakyReLU(0.2, True)
448
+ ]
449
+
450
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
451
+ self.model = nn.Sequential(*sequence)
452
+
453
+ def forward(self, input):
454
+ """Standard forward."""
455
+ return self.model(input)
456
+
457
+
458
+ class G_Unet_add_input(nn.Module):
459
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
460
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
461
+ upsample='basic', device=0):
462
+ super(G_Unet_add_input, self).__init__()
463
+ self.nz = nz
464
+ max_nchn = 8
465
+ noise = []
466
+ for i in range(num_downs+1):
467
+ if use_noise:
468
+ noise.append(True)
469
+ else:
470
+ noise.append(False)
471
+
472
+ # construct unet structure
473
+ #print(num_downs)
474
+ unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
475
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
476
+ for i in range(num_downs - 5):
477
+ unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
478
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
479
+ unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
480
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
481
+ unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
482
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
483
+ unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
484
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
485
+ unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
486
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
487
+
488
+ self.model = unet_block
489
+
490
+ def forward(self, x, z=None):
491
+ if self.nz > 0:
492
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
493
+ z.size(0), z.size(1), x.size(2), x.size(3))
494
+ x_with_z = torch.cat([x, z_img], 1)
495
+ else:
496
+ x_with_z = x # no z
497
+
498
+
499
+ return torch.tanh(self.model(x_with_z))
500
+ # return self.model(x_with_z)
501
+
502
+ class G_Unet_add_input_G(nn.Module):
503
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
504
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
505
+ upsample='basic', device=0):
506
+ super(G_Unet_add_input_G, self).__init__()
507
+ self.nz = nz
508
+ max_nchn = 8
509
+ noise = []
510
+ for i in range(num_downs+1):
511
+ if use_noise:
512
+ noise.append(True)
513
+ else:
514
+ noise.append(False)
515
+ # construct unet structure
516
+ #print(num_downs)
517
+ unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
518
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
519
+ for i in range(num_downs - 5):
520
+ unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
521
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
522
+ unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
523
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
524
+ unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
525
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
526
+ unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
527
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
528
+ unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
529
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
530
+
531
+ self.model = unet_block
532
+
533
+ def forward(self, x, z=None):
534
+ if self.nz > 0:
535
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
536
+ z.size(0), z.size(1), x.size(2), x.size(3))
537
+ x_with_z = torch.cat([x, z_img], 1)
538
+ else:
539
+ x_with_z = x # no z
540
+
541
+ # return F.tanh(self.model(x_with_z))
542
+ return self.model(x_with_z)
543
+
544
+ class G_Unet_add_input_C(nn.Module):
545
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
546
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
547
+ upsample='basic', device=0):
548
+ super(G_Unet_add_input_C, self).__init__()
549
+ self.nz = nz
550
+ max_nchn = 8
551
+ # construct unet structure
552
+ #print(num_downs)
553
+ unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
554
+ innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
555
+ for i in range(num_downs - 5):
556
+ unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
557
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
558
+ unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
559
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
560
+ unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
561
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
562
+ unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
563
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
564
+ unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
565
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
566
+
567
+ self.model = unet_block
568
+
569
+ def forward(self, x, z=None):
570
+ if self.nz > 0:
571
+ z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
572
+ z.size(0), z.size(1), x.size(2), x.size(3))
573
+ x_with_z = torch.cat([x, z_img], 1)
574
+ else:
575
+ x_with_z = x # no z
576
+
577
+ # return torch.tanh(self.model(x_with_z))
578
+ return self.model(x_with_z)
579
+
580
+ def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
581
+ # padding_type = 'zero'
582
+ if upsample == 'basic':
583
+ upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]#, padding_mode='replicate'
584
+ elif upsample == 'bilinear' or upsample == 'nearest' or upsample == 'linear':
585
+ upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
586
+ #nn.ReplicationPad2d(1),
587
+ nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
588
+ # p = kw//2
589
+ # upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
590
+ # nn.Conv2d(inplanes, outplanes, kernel_size=kw, stride=1, padding=p, padding_mode='replicate')]
591
+ else:
592
+ raise NotImplementedError(
593
+ 'upsample layer [%s] not implemented' % upsample)
594
+ return upconv
595
+
596
+ class UnetBlock_G(nn.Module):
597
+ def __init__(self, input_nc, outer_nc, inner_nc,
598
+ submodule=None, noise=None, outermost=False, innermost=False,
599
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
600
+ super(UnetBlock_G, self).__init__()
601
+ self.outermost = outermost
602
+ p = 0
603
+ downconv = []
604
+ if padding_type == 'reflect':
605
+ downconv += [nn.ReflectionPad2d(1)]
606
+ elif padding_type == 'replicate':
607
+ downconv += [nn.ReplicationPad2d(1)]
608
+ elif padding_type == 'zero':
609
+ p = 1
610
+ else:
611
+ raise NotImplementedError(
612
+ 'padding [%s] is not implemented' % padding_type)
613
+
614
+ downconv += [nn.Conv2d(input_nc, inner_nc,
615
+ kernel_size=3, stride=2, padding=p)]
616
+ # downsample is different from upsample
617
+ downrelu = nn.LeakyReLU(0.2, True)
618
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
619
+ uprelu = nl_layer()
620
+ uprelu2 = nl_layer()
621
+ uppad = nn.ReplicationPad2d(1)
622
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
623
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
624
+ self.noiseblock = ApplyNoise(outer_nc)
625
+ self.noise = noise
626
+
627
+ if outermost:
628
+ upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
629
+ uppad = nn.ReplicationPad2d(3)
630
+ upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
631
+ down = downconv
632
+ up = [uprelu] + upconv
633
+ if upnorm is not None:
634
+ up += [norm_layer(inner_nc)]
635
+ # upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
636
+ # upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=0)
637
+ # down = downconv
638
+ # up = [uprelu] + upconv
639
+ # if upnorm is not None:
640
+ # up += [norm_layer(outer_nc)]
641
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
642
+ model = down + [submodule] + up
643
+ elif innermost:
644
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
645
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
646
+ down = [downrelu] + downconv
647
+ up = [uprelu] + upconv
648
+ if upnorm is not None:
649
+ up += [upnorm]
650
+ up += [uprelu2, uppad, upconv2]
651
+ if upnorm2 is not None:
652
+ up += [upnorm2]
653
+ model = down + up
654
+ else:
655
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
656
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
657
+ down = [downrelu] + downconv
658
+ if downnorm is not None:
659
+ down += [downnorm]
660
+ up = [uprelu] + upconv
661
+ if upnorm is not None:
662
+ up += [upnorm]
663
+ up += [uprelu2, uppad, upconv2]
664
+ if upnorm2 is not None:
665
+ up += [upnorm2]
666
+
667
+ if use_dropout:
668
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
669
+ else:
670
+ model = down + [submodule] + up
671
+
672
+ self.model = nn.Sequential(*model)
673
+
674
+ def forward(self, x):
675
+ if self.outermost:
676
+ return self.model(x)
677
+ else:
678
+ x2 = self.model(x)
679
+ if self.noise:
680
+ x2 = self.noiseblock(x2, self.noise)
681
+ return torch.cat([x2, x], 1)
682
+
683
+
684
+ class UnetBlock(nn.Module):
685
+ def __init__(self, input_nc, outer_nc, inner_nc,
686
+ submodule=None, noise=None, outermost=False, innermost=False,
687
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
688
+ super(UnetBlock, self).__init__()
689
+ self.outermost = outermost
690
+ p = 0
691
+ downconv = []
692
+ if padding_type == 'reflect':
693
+ downconv += [nn.ReflectionPad2d(1)]
694
+ elif padding_type == 'replicate':
695
+ downconv += [nn.ReplicationPad2d(1)]
696
+ elif padding_type == 'zero':
697
+ p = 1
698
+ else:
699
+ raise NotImplementedError(
700
+ 'padding [%s] is not implemented' % padding_type)
701
+
702
+ downconv += [nn.Conv2d(input_nc, inner_nc,
703
+ kernel_size=3, stride=2, padding=p)]
704
+ # downsample is different from upsample
705
+ downrelu = nn.LeakyReLU(0.2, True)
706
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
707
+ uprelu = nl_layer()
708
+ uprelu2 = nl_layer()
709
+ uppad = nn.ReplicationPad2d(1)
710
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
711
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
712
+ self.noiseblock = ApplyNoise(outer_nc)
713
+ self.noise = noise
714
+
715
+ if outermost:
716
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
717
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
718
+ down = downconv
719
+ up = [uprelu] + upconv
720
+ if upnorm is not None:
721
+ up += [upnorm]
722
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
723
+ model = down + [submodule] + up
724
+ elif innermost:
725
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
726
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
727
+ down = [downrelu] + downconv
728
+ up = [uprelu] + upconv
729
+ if upnorm is not None:
730
+ up += [upnorm]
731
+ up += [uprelu2, uppad, upconv2]
732
+ if upnorm2 is not None:
733
+ up += [upnorm2]
734
+ model = down + up
735
+ else:
736
+ upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
737
+ upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
738
+ down = [downrelu] + downconv
739
+ if downnorm is not None:
740
+ down += [downnorm]
741
+ up = [uprelu] + upconv
742
+ if upnorm is not None:
743
+ up += [upnorm]
744
+ up += [uprelu2, uppad, upconv2]
745
+ if upnorm2 is not None:
746
+ up += [upnorm2]
747
+
748
+ if use_dropout:
749
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
750
+ else:
751
+ model = down + [submodule] + up
752
+
753
+ self.model = nn.Sequential(*model)
754
+
755
+ def forward(self, x):
756
+ if self.outermost:
757
+ return self.model(x)
758
+ else:
759
+ x2 = self.model(x)
760
+ if self.noise:
761
+ x2 = self.noiseblock(x2, self.noise)
762
+ return torch.cat([x2, x], 1)
763
+
764
+ # Defines the submodule with skip connection.
765
+ # X -------------------identity---------------------- X
766
+ # |-- downsampling -- |submodule| -- upsampling --|
767
+ class UnetBlock_A(nn.Module):
768
+ def __init__(self, input_nc, outer_nc, inner_nc,
769
+ submodule=None, noise=None, outermost=False, innermost=False,
770
+ norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
771
+ super(UnetBlock_A, self).__init__()
772
+ self.outermost = outermost
773
+ p = 0
774
+ downconv = []
775
+ if padding_type == 'reflect':
776
+ downconv += [nn.ReflectionPad2d(1)]
777
+ elif padding_type == 'replicate':
778
+ downconv += [nn.ReplicationPad2d(1)]
779
+ elif padding_type == 'zero':
780
+ p = 1
781
+ else:
782
+ raise NotImplementedError(
783
+ 'padding [%s] is not implemented' % padding_type)
784
+
785
+ downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
786
+ kernel_size=3, stride=2, padding=p))]
787
+ # downsample is different from upsample
788
+ downrelu = nn.LeakyReLU(0.2, True)
789
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
790
+ uprelu = nl_layer()
791
+ uprelu2 = nl_layer()
792
+ uppad = nn.ReplicationPad2d(1)
793
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
794
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
795
+ self.noiseblock = ApplyNoise(outer_nc)
796
+ self.noise = noise
797
+
798
+ if outermost:
799
+ upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
800
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
801
+ down = downconv
802
+ up = [uprelu] + upconv
803
+ if upnorm is not None:
804
+ up += [upnorm]
805
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
806
+ model = down + [submodule] + up
807
+ elif innermost:
808
+ upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
809
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
810
+ down = [downrelu] + downconv
811
+ up = [uprelu] + upconv
812
+ if upnorm is not None:
813
+ up += [upnorm]
814
+ up += [uprelu2, uppad, upconv2]
815
+ if upnorm2 is not None:
816
+ up += [upnorm2]
817
+ model = down + up
818
+ else:
819
+ upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
820
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
821
+ down = [downrelu] + downconv
822
+ if downnorm is not None:
823
+ down += [downnorm]
824
+ up = [uprelu] + upconv
825
+ if upnorm is not None:
826
+ up += [upnorm]
827
+ up += [uprelu2, uppad, upconv2]
828
+ if upnorm2 is not None:
829
+ up += [upnorm2]
830
+
831
+ if use_dropout:
832
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
833
+ else:
834
+ model = down + [submodule] + up
835
+
836
+ self.model = nn.Sequential(*model)
837
+
838
+ def forward(self, x):
839
+ if self.outermost:
840
+ return self.model(x)
841
+ else:
842
+ x2 = self.model(x)
843
+ if self.noise:
844
+ x2 = self.noiseblock(x2, self.noise)
845
+ if x2.shape[-1]==x.shape[-1]:
846
+ return x2 + x
847
+ else:
848
+ x2 = F.interpolate(x2, x.shape[2:])
849
+ return x2 + x
850
+
851
+
852
+ class E_ResNet(nn.Module):
853
+ def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
854
+ norm_layer=None, nl_layer=None, vaeLike=False):
855
+ super(E_ResNet, self).__init__()
856
+ self.vaeLike = vaeLike
857
+ max_ndf = 4
858
+ conv_layers = [
859
+ nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
860
+ for n in range(1, n_blocks):
861
+ input_ndf = ndf * min(max_ndf, n)
862
+ output_ndf = ndf * min(max_ndf, n + 1)
863
+ conv_layers += [BasicBlock(input_ndf,
864
+ output_ndf, norm_layer, nl_layer)]
865
+ conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
866
+ if vaeLike:
867
+ self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
868
+ self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
869
+ else:
870
+ self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
871
+ self.conv = nn.Sequential(*conv_layers)
872
+
873
+ def forward(self, x):
874
+ x_conv = self.conv(x)
875
+ conv_flat = x_conv.view(x.size(0), -1)
876
+ output = self.fc(conv_flat)
877
+ if self.vaeLike:
878
+ outputVar = self.fcVar(conv_flat)
879
+ return output, outputVar
880
+ else:
881
+ return output
882
+ return output
883
+
884
+
885
+ # Defines the Unet generator.
886
+ # |num_downs|: number of downsamplings in UNet. For example,
887
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
888
+ # at the bottleneck
889
+ class G_Unet_add_all(nn.Module):
890
+ def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
891
+ norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
892
+ super(G_Unet_add_all, self).__init__()
893
+ self.nz = nz
894
+ self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
895
+ self.truncation_psi = 0
896
+ self.truncation_cutoff = 0
897
+
898
+ # - 2 means we start from feature map with height and width equals 4.
899
+ # as this example, we get num_layers = 18.
900
+ num_layers = int(np.log2(512)) * 2 - 2
901
+ # Noise inputs.
902
+ self.noise_inputs = []
903
+ for layer_idx in range(num_layers):
904
+ res = layer_idx // 2 + 2
905
+ shape = [1, 1, 2 ** res, 2 ** res]
906
+ self.noise_inputs.append(torch.randn(*shape).to("cuda" if torch.cuda.is_available() else "cpu"))
907
+
908
+ # construct unet structure
909
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
910
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
911
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
912
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
913
+ for i in range(num_downs - 6):
914
+ unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
915
+ norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
916
+ unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
917
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
918
+ unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
919
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
920
+ unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
921
+ norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
922
+ unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
923
+ outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
924
+ self.model = unet_block
925
+
926
+ def forward(self, x, z):
927
+
928
+ dlatents1, num_layers = self.mapping(z)
929
+ dlatents1 = dlatents1.unsqueeze(1)
930
+ dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
931
+
932
+ # Apply truncation trick.
933
+ if self.truncation_psi and self.truncation_cutoff:
934
+ coefs = np.ones([1, num_layers, 1], dtype=np.float32)
935
+ for i in range(num_layers):
936
+ if i < self.truncation_cutoff:
937
+ coefs[:, i, :] *= self.truncation_psi
938
+ """Linear interpolation.
939
+ a + (b - a) * t (a = 0)
940
+ reduce to
941
+ b * t
942
+ """
943
+ dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
944
+
945
+ return torch.tanh(self.model(x, dlatents1, self.noise_inputs))
946
+
947
+
948
+ class ApplyNoise(nn.Module):
949
+ def __init__(self, channels):
950
+ super().__init__()
951
+ self.channels = channels
952
+ self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
953
+ self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
954
+
955
+ def forward(self, x, noise):
956
+ W,_ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
957
+ B,_ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
958
+ Z = torch.zeros_like(W)
959
+ w = torch.cat([W,Z], dim=1).to(x.device)
960
+ b = torch.cat([B,Z], dim=1).to(x.device)
961
+ adds = w * torch.randn_like(x) + b
962
+ return x + adds.type_as(x)
963
+
964
+
965
+ class FC(nn.Module):
966
+ def __init__(self,
967
+ in_channels,
968
+ out_channels,
969
+ gain=2**(0.5),
970
+ use_wscale=False,
971
+ lrmul=1.0,
972
+ bias=True):
973
+ """
974
+ The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
975
+ """
976
+ super(FC, self).__init__()
977
+ he_std = gain * in_channels ** (-0.5) # He init
978
+ if use_wscale:
979
+ init_std = 1.0 / lrmul
980
+ self.w_lrmul = he_std * lrmul
981
+ else:
982
+ init_std = he_std / lrmul
983
+ self.w_lrmul = lrmul
984
+
985
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
986
+ if bias:
987
+ self.bias = torch.nn.Parameter(torch.zeros(out_channels))
988
+ self.b_lrmul = lrmul
989
+ else:
990
+ self.bias = None
991
+
992
+ def forward(self, x):
993
+ if self.bias is not None:
994
+ out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
995
+ else:
996
+ out = F.linear(x, self.weight * self.w_lrmul)
997
+ out = F.leaky_relu(out, 0.2, inplace=True)
998
+ return out
999
+
1000
+
1001
+ class ApplyStyle(nn.Module):
1002
  """
1003
+ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1004
  """
1005
+ def __init__(self, latent_size, channels, use_wscale, nl_layer):
1006
+ super(ApplyStyle, self).__init__()
1007
+ modules = [nn.Linear(latent_size, channels*2)]
1008
+ if nl_layer:
1009
+ modules += [nl_layer()]
1010
+ self.linear = nn.Sequential(*modules)
1011
+
1012
+ def forward(self, x, latent):
1013
+ style = self.linear(latent) # style => [batch_size, n_channels*2]
1014
+ shape = [-1, 2, x.size(1), 1, 1]
1015
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
1016
+ x = x * (style[:, 0] + 1.) + style[:, 1]
1017
+ return x
1018
+
1019
+ class PixelNorm(nn.Module):
1020
+ def __init__(self, epsilon=1e-8):
1021
+ """
1022
+ @notice: avoid in-place ops.
1023
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
1024
+ """
1025
+ super(PixelNorm, self).__init__()
1026
+ self.epsilon = epsilon
1027
+
1028
+ def forward(self, x):
1029
+ tmp = torch.mul(x, x) # or x ** 2
1030
+ tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
1031
+
1032
+ return x * tmp1
1033
+
1034
+
1035
+ class InstanceNorm(nn.Module):
1036
+ def __init__(self, epsilon=1e-8):
1037
+ """
1038
+ @notice: avoid in-place ops.
1039
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
1040
+ """
1041
+ super(InstanceNorm, self).__init__()
1042
+ self.epsilon = epsilon
1043
+
1044
+ def forward(self, x):
1045
+ x = x - torch.mean(x, (2, 3), True)
1046
+ tmp = torch.mul(x, x) # or x ** 2
1047
+ tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
1048
+ return x * tmp
1049
+
1050
+
1051
+ class LayerEpilogue(nn.Module):
1052
+ def __init__(self, channels, dlatent_size, use_wscale, use_noise,
1053
+ use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
1054
+ super(LayerEpilogue, self).__init__()
1055
+ self.use_noise = use_noise
1056
+ if use_noise:
1057
+ self.noise = ApplyNoise(channels)
1058
+ self.act = nn.LeakyReLU(negative_slope=0.2)
1059
+
1060
+ if use_pixel_norm:
1061
+ self.pixel_norm = PixelNorm()
1062
+ else:
1063
+ self.pixel_norm = None
1064
+
1065
+ if use_instance_norm:
1066
+ self.instance_norm = InstanceNorm()
1067
+ else:
1068
+ self.instance_norm = None
1069
+
1070
+ if use_styles:
1071
+ self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
1072
+ else:
1073
+ self.style_mod = None
1074
+
1075
+ def forward(self, x, noise, dlatents_in_slice=None):
1076
+ # if noise is not None:
1077
+ if self.use_noise:
1078
+ x = self.noise(x, noise)
1079
+ x = self.act(x)
1080
+ if self.pixel_norm is not None:
1081
+ x = self.pixel_norm(x)
1082
+ if self.instance_norm is not None:
1083
+ x = self.instance_norm(x)
1084
+ if self.style_mod is not None:
1085
+ x = self.style_mod(x, dlatents_in_slice)
1086
+
1087
+ return x
1088
+
1089
+ class G_mapping(nn.Module):
1090
+ def __init__(self,
1091
+ mapping_fmaps=512,
1092
+ dlatent_size=512,
1093
+ resolution=512,
1094
+ normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
1095
+ use_wscale=True, # Enable equalized learning rate?
1096
+ lrmul=0.01, # Learning rate multiplier for the mapping layers.
1097
+ gain=2**(0.5), # original gain in tensorflow.
1098
+ nl_layer=None
1099
+ ):
1100
+ super(G_mapping, self).__init__()
1101
+ self.mapping_fmaps = mapping_fmaps
1102
+ func = [
1103
+ nn.Linear(self.mapping_fmaps, dlatent_size)
1104
+ ]
1105
+ if nl_layer:
1106
+ func += [nl_layer()]
1107
+
1108
+ for j in range(0,4):
1109
+ func += [
1110
+ nn.Linear(dlatent_size, dlatent_size)
1111
+ ]
1112
+ if nl_layer:
1113
+ func += [nl_layer()]
1114
+
1115
+ self.func = nn.Sequential(*func)
1116
+ #FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
1117
+ #FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
1118
+
1119
+ self.normalize_latents = normalize_latents
1120
+ self.resolution_log2 = int(np.log2(resolution))
1121
+ self.num_layers = self.resolution_log2 * 2 - 2
1122
+ self.pixel_norm = PixelNorm()
1123
+ # - 2 means we start from feature map with height and width equals 4.
1124
+ # as this example, we get num_layers = 18.
1125
+
1126
+ def forward(self, x):
1127
+ if self.normalize_latents:
1128
+ x = self.pixel_norm(x)
1129
+ out = self.func(x)
1130
+ return out, self.num_layers
1131
+
1132
+ class UnetBlock_with_z(nn.Module):
1133
+ def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
1134
+ submodule=None, outermost=False, innermost=False,
1135
+ norm_layer=None, nl_layer=None, use_dropout=False,
1136
+ upsample='basic', padding_type='replicate'):
1137
+ super(UnetBlock_with_z, self).__init__()
1138
+ p = 0
1139
+ downconv = []
1140
+ if padding_type == 'reflect':
1141
+ downconv += [nn.ReflectionPad2d(1)]
1142
+ elif padding_type == 'replicate':
1143
+ downconv += [nn.ReplicationPad2d(1)]
1144
+ elif padding_type == 'zero':
1145
+ p = 1
1146
+ else:
1147
+ raise NotImplementedError(
1148
+ 'padding [%s] is not implemented' % padding_type)
1149
+
1150
+ self.outermost = outermost
1151
+ self.innermost = innermost
1152
+ self.nz = nz
1153
+
1154
+ # input_nc = input_nc + nz
1155
+ downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
1156
+ kernel_size=3, stride=2, padding=p))]
1157
+ # downsample is different from upsample
1158
+ downrelu = nn.LeakyReLU(0.2, True)
1159
+ downnorm = norm_layer(inner_nc) if norm_layer is not None else None
1160
+ uprelu = nl_layer()
1161
+ uprelu2 = nl_layer()
1162
+ uppad = nn.ReplicationPad2d(1)
1163
+ upnorm = norm_layer(outer_nc) if norm_layer is not None else None
1164
+ upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
1165
+
1166
+ use_styles=False
1167
+ uprelu = nl_layer()
1168
+ if self.nz >0:
1169
+ use_styles=True
1170
+
1171
+ if outermost:
1172
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
1173
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1174
+ upconv = upsampleLayer(
1175
+ inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
1176
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1177
+ down = downconv
1178
+ up = [uprelu] + upconv
1179
+ if upnorm is not None:
1180
+ up += [upnorm]
1181
+ up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
1182
+ elif innermost:
1183
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
1184
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1185
+ upconv = upsampleLayer(
1186
+ inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
1187
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1188
+ down = [downrelu] + downconv
1189
+ up = [uprelu] + upconv
1190
+ if norm_layer is not None:
1191
+ up += [norm_layer(outer_nc)]
1192
+ up += [uprelu2, uppad, upconv2]
1193
+ if upnorm2 is not None:
1194
+ up += [upnorm2]
1195
+ else:
1196
+ self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
1197
+ use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
1198
+ upconv = upsampleLayer(
1199
+ inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
1200
+ upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
1201
+ down = [downrelu] + downconv
1202
+ if norm_layer is not None:
1203
+ down += [norm_layer(inner_nc)]
1204
+ up = [uprelu] + upconv
1205
+
1206
+ if norm_layer is not None:
1207
+ up += [norm_layer(outer_nc)]
1208
+ up += [uprelu2, uppad, upconv2]
1209
+ if upnorm2 is not None:
1210
+ up += [upnorm2]
1211
+
1212
+ if use_dropout:
1213
+ up += [nn.Dropout(0.5)]
1214
+ self.down = nn.Sequential(*down)
1215
+ self.submodule = submodule
1216
+ self.up = nn.Sequential(*up)
1217
+
1218
+
1219
+ def forward(self, x, z, noise):
1220
+ if self.outermost:
1221
+ x1 = self.down(x)
1222
+ x2 = self.submodule(x1, z[:,2:], noise[2:])
1223
+ return self.up(x2)
1224
+
1225
+ elif self.innermost:
1226
+ x1 = self.down(x)
1227
+ x_and_z = self.adaIn(x1, noise[0], z[:,0])
1228
+ x2 = self.up(x_and_z)
1229
+ x2 = F.interpolate(x2, x.shape[2:])
1230
+ return x2 + x
1231
+
1232
+ else:
1233
+ x1 = self.down(x)
1234
+ x2 = self.submodule(x1, z[:,2:], noise[2:])
1235
+ x_and_z = self.adaIn(x2, noise[0], z[:,0])
1236
+ return self.up(x_and_z) + x
1237
+
1238
+
1239
+ class E_NLayers(nn.Module):
1240
+ def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
1241
+ norm_layer=None, nl_layer=None, vaeLike=False):
1242
+ super(E_NLayers, self).__init__()
1243
+ self.vaeLike = vaeLike
1244
+
1245
+ kw, padw = 3, 1
1246
+ sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
1247
+ stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
1248
+
1249
+ nf_mult = 1
1250
+ nf_mult_prev = 1
1251
+ for n in range(1, n_layers):
1252
+ nf_mult_prev = nf_mult
1253
+ nf_mult = min(2**n, 8)
1254
+ sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1255
+ kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
1256
+ if norm_layer is not None:
1257
+ sequence += [norm_layer(ndf * nf_mult)]
1258
+ sequence += [nl_layer()]
1259
+ sequence += [nn.AdaptiveAvgPool2d(4)]
1260
+ self.conv = nn.Sequential(*sequence)
1261
+ self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
1262
+ if vaeLike:
1263
+ self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
1264
+
1265
+ def forward(self, x):
1266
+ x_conv = self.conv(x)
1267
+ conv_flat = x_conv.view(x.size(0), -1)
1268
+ output = self.fc(conv_flat)
1269
+ if self.vaeLike:
1270
+ outputVar = self.fcVar(conv_flat)
1271
+ return output, outputVar
1272
+ return output
1273
+
1274
+ class BasicBlock(nn.Module):
1275
+ def __init__(self, inplanes, outplanes):
1276
+ super(BasicBlock, self).__init__()
1277
+ layers = []
1278
+ norm_layer=get_norm_layer(norm_type='layer') #functools.partial(LayerNorm)
1279
+ # norm_layer = None
1280
+ nl_layer=nn.ReLU()
1281
+ if norm_layer is not None:
1282
+ layers += [norm_layer(inplanes)]
1283
+ layers += [nl_layer]
1284
+ layers += [nn.ReplicationPad2d(1),
1285
+ nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
1286
+ padding=0, bias=True)]
1287
+ self.conv = nn.Sequential(*layers)
1288
+
1289
+ def forward(self, x):
1290
+ return self.conv(x)
1291
+
1292
+
1293
+ def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
1294
+ init_type="normal", init_gain=0.02, gpu_ids=[]):
1295
+ if netVAE == 'SVAE':
1296
+ net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
1297
+ init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
1298
+ else:
1299
+ raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
1300
+ init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
1301
+ net.load_networks('latest')
1302
+ return net
1303
+
1304
+
1305
+ class ScreenVAE(nn.Module):
1306
+ def __init__(self,inc=1,outc=4, outplanes=64, downs=5, blocks=2,load_ext=True, save_dir='',init_type="normal", init_gain=0.02, gpu_ids=[]):
1307
+ super(ScreenVAE, self).__init__()
1308
+ self.inc = inc
1309
+ self.outc = outc
1310
+ self.save_dir = save_dir
1311
+ norm_layer=functools.partial(LayerNormWarpper)
1312
+ nl_layer=nn.LeakyReLU
1313
+
1314
+ self.model_names=['enc','dec']
1315
+ self.enc=define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks',
1316
+ norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
1317
+ gpu_ids=gpu_ids, upsample='bilinear')
1318
+ self.dec=define_G(outc, inc, 0, 48, netG='unet_128_G',
1319
+ norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
1320
+ gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)
1321
+
1322
+ for param in self.parameters():
1323
+ param.requires_grad = False
1324
+
1325
+ def load_networks(self, epoch):
1326
+ """Load all the networks from the disk.
1327
+
1328
+ Parameters:
1329
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
1330
+ """
1331
+ for name in self.model_names:
1332
+ if isinstance(name, str):
1333
+ load_filename = '%s_net_%s.pth' % (epoch, name)
1334
+ load_path = os.path.join(self.save_dir, load_filename)
1335
+ net = getattr(self, name)
1336
+ if isinstance(net, torch.nn.DataParallel):
1337
+ net = net.module
1338
+ print('loading the model from %s' % load_path)
1339
+ state_dict = torch.load(
1340
+ load_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
1341
+ if hasattr(state_dict, '_metadata'):
1342
+ del state_dict._metadata
1343
+
1344
+ net.load_state_dict(state_dict)
1345
+ del state_dict
1346
+
1347
+ def npad(self, im, pad=128):
1348
+ h,w = im.shape[-2:]
1349
+ hp = h //pad*pad+pad
1350
+ wp = w //pad*pad+pad
1351
+ return F.pad(im, (0, wp-w, 0, hp-h), mode='replicate')
1352
+
1353
+ def forward(self, x, line=None, img_input=True, output_screen_only=True):
1354
+ if img_input:
1355
+ if line is None:
1356
+ line = torch.ones_like(x)
1357
+ else:
1358
+ line = torch.sign(line)
1359
+ x = torch.clamp(x + (1-line),-1,1)
1360
+ h,w = x.shape[-2:]
1361
+ input = torch.cat([x, line], 1)
1362
+ input = self.npad(input)
1363
+ inter = self.enc(input)[:,:,:h,:w]
1364
+ scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
1365
+ if output_screen_only:
1366
+ return scr
1367
+ recons = self.dec(scr)
1368
+ return recons, scr, logvar
1369
+ else:
1370
+ h,w = x.shape[-2:]
1371
+ x = self.npad(x)
1372
+ recons = self.dec(x)[:,:,:h,:w]
1373
+ recons = (recons+1)*(line+1)/2-1
1374
+ return torch.clamp(recons,-1,1)