facehuggingjay committed on
Commit 3fb5067 · verified · 1 Parent(s): 85c91be
Files changed (1)
  1. app.py +574 -1368
app.py CHANGED
@@ -1,1374 +1,580 @@
- import torch
- import torch.nn as nn
- from torch.nn import init
- import functools
- from torch.optim import lr_scheduler
  import numpy as np
  import torch.nn.functional as F
- from torch.nn.modules.normalization import LayerNorm
- import os
- from torch.nn.utils import spectral_norm
- from torchvision import models
-
- ###############################################################################
- # Helper functions
- ###############################################################################
-
-
- def init_weights(net, init_type='normal', init_gain=0.02):
-     """Initialize network weights.
-     Parameters:
-         net (network)     -- network to be initialized
-         init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
-         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
-     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
-     work better for some applications. Feel free to try yourself.
-     """
-     def init_func(m):  # define the initialization function
-         classname = m.__class__.__name__
-         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
-             if init_type == 'normal':
-                 init.normal_(m.weight.data, 0.0, init_gain)
-             elif init_type == 'xavier':
-                 init.xavier_normal_(m.weight.data, gain=init_gain)
-             elif init_type == 'kaiming':
-                 #init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-                 init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
-             elif init_type == 'orthogonal':
-                 init.orthogonal_(m.weight.data, gain=init_gain)
-             else:
-                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
-             if hasattr(m, 'bias') and m.bias is not None:
-                 init.constant_(m.bias.data, 0.0)
-         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
-             init.normal_(m.weight.data, 1.0, init_gain)
-             init.constant_(m.bias.data, 0.0)
-
-     print('initialize network with %s' % init_type)
-     net.apply(init_func)  # apply the initialization function <init_func>
-
-
- def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
-     """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
-     Parameters:
-         net (network)      -- the network to be initialized
-         init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
-         gain (float)       -- scaling factor for normal, xavier and orthogonal.
-         gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
-     Return an initialized network.
-     """
-     if len(gpu_ids) > 0 and torch.cuda.is_available():
-         net.to(gpu_ids[0])
-     if init:
-         init_weights(net, init_type, init_gain=init_gain)
-     return net
-
-
- def get_scheduler(optimizer, opt):
-     """Return a learning rate scheduler
-     Parameters:
-         optimizer          -- the optimizer of the network
-         opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
-                               opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
-     For 'linear', we keep the same learning rate for the first <opt.niter> epochs
-     and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
-     For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
-     See https://pytorch.org/docs/stable/optim.html for more details.
-     """
-     if opt.lr_policy == 'linear':
-         def lambda_rule(epoch):
-             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
-             return lr_l
-         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
-     elif opt.lr_policy == 'step':
-         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
-     elif opt.lr_policy == 'plateau':
-         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
-     elif opt.lr_policy == 'cosine':
-         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
-     else:
-         return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
-     return scheduler
-
- class LayerNormWarpper(nn.Module):
-     def __init__(self, num_features):
-         super(LayerNormWarpper, self).__init__()
-         self.num_features = int(num_features)
-
-     def forward(self, x):
-         x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).to(x.device)(x)
-         return x
-
- def get_norm_layer(norm_type='instance'):
-     """Return a normalization layer
-     Parameters:
-         norm_type (str) -- the name of the normalization layer: batch | instance | none
-     For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
-     For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
-     """
-     if norm_type == 'batch':
-         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
-     elif norm_type == 'instance':
-         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
-     elif norm_type == 'layer':
-         norm_layer = functools.partial(LayerNormWarpper)
-     elif norm_type == 'none':
-         norm_layer = None
-     else:
-         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
-     return norm_layer
-
-
- def get_non_linearity(layer_type='relu'):
-     if layer_type == 'relu':
-         nl_layer = functools.partial(nn.ReLU, inplace=True)
-     elif layer_type == 'lrelu':
-         nl_layer = functools.partial(
-             nn.LeakyReLU, negative_slope=0.2, inplace=True)
-     elif layer_type == 'elu':
-         nl_layer = functools.partial(nn.ELU, inplace=True)
-     elif layer_type == 'selu':
-         nl_layer = functools.partial(nn.SELU, inplace=True)
-     elif layer_type == 'prelu':
-         nl_layer = functools.partial(nn.PReLU)
-     else:
-         raise NotImplementedError(
-             'nonlinearity activitation [%s] is not found' % layer_type)
-     return nl_layer
-
-
- def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
-              use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
-     net = None
-     norm_layer = get_norm_layer(norm_type=norm)
-     nl_layer = get_non_linearity(layer_type=nl)
-     # print(norm, norm_layer)
-
-     if nz == 0:
-         where_add = 'input'
-
-     if netG == 'unet_128' and where_add == 'input':
-         net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                                use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
-     elif netG == 'unet_128_G' and where_add == 'input':
-         net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                                  use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
-     elif netG == 'unet_256' and where_add == 'input':
-         net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                                use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
-     elif netG == 'unet_256_G' and where_add == 'input':
-         net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                                  use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
-     elif netG == 'unet_128' and where_add == 'all':
-         net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                              use_dropout=use_dropout, upsample=upsample)
-     elif netG == 'unet_256' and where_add == 'all':
-         net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
-                              use_dropout=use_dropout, upsample=upsample)
-     else:
-         raise NotImplementedError('Generator model name [%s] is not recognized' % net)
-     # print(net)
-     return init_net(net, init_type, init_gain, gpu_ids)
-
-
- def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
-              use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
-     net = None
-     norm_layer = get_norm_layer(norm_type=norm)
-     nl_layer = get_non_linearity(layer_type=nl)
-
-     if netC == 'resnet_9blocks':
-         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
-     elif netC == 'resnet_6blocks':
-         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
-     elif netC == 'unet_128':
-         net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
-                                  use_dropout=use_dropout, upsample=upsample)
-     elif netC == 'unet_256':
-         net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
-                                use_dropout=use_dropout, upsample=upsample)
-     elif netC == 'unet_32':
-         net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
-                                use_dropout=use_dropout, upsample=upsample)
-     else:
-         raise NotImplementedError('Generator model name [%s] is not recognized' % net)
-
-     return init_net(net, init_type, init_gain, gpu_ids)
-
-
- def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
-     net = None
-     norm_layer = get_norm_layer(norm_type=norm)
-     nl = 'lrelu'  # use leaky relu for D
-     nl_layer = get_non_linearity(layer_type=nl)
-
-     if netD == 'basic_128':
-         net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
-     elif netD == 'basic_256':
-         net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
-     elif netD == 'basic_128_multi':
-         net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
-     elif netD == 'basic_256_multi':
-         net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
-     else:
-         raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
-     return init_net(net, init_type, init_gain, gpu_ids)
-
-
- def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
-              init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
-     net = None
-     norm_layer = get_norm_layer(norm_type=norm)
-     nl = 'lrelu'  # use leaky relu for E
-     nl_layer = get_non_linearity(layer_type=nl)
-     if netE == 'resnet_128':
-         net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
-                        nl_layer=nl_layer, vaeLike=vaeLike)
-     elif netE == 'resnet_256':
-         net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
-                        nl_layer=nl_layer, vaeLike=vaeLike)
-     elif netE == 'conv_128':
-         net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
-                         nl_layer=nl_layer, vaeLike=vaeLike)
-     elif netE == 'conv_256':
-         net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
-                         nl_layer=nl_layer, vaeLike=vaeLike)
-     else:
-         raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
-
-     return init_net(net, init_type, init_gain, gpu_ids, False)
-
-
- class ResnetGenerator(nn.Module):
-     def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
-         assert(n_blocks >= 0)
-         super(ResnetGenerator, self).__init__()
-         self.input_nc = input_nc
-         self.output_nc = output_nc
-         self.ngf = ngf
-         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
-             use_bias = norm_layer.func != nn.BatchNorm2d
          else:
-             use_bias = norm_layer != nn.BatchNorm2d
-
-         model = [nn.ReplicationPad2d(3),
-                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
-                            bias=use_bias)]
-         if norm_layer is not None:
-             model += [norm_layer(ngf)]
-         model += [nn.ReLU(True)]
-
-         # n_downsampling = 2
-         for i in range(n_downsampling):
-             mult = 2**i
-             model += [nn.ReplicationPad2d(1), nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
-                       stride=2, padding=0, bias=use_bias)]
-             # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
-             #                     stride=2, padding=1, bias=use_bias)]
-             if norm_layer is not None:
-                 model += [norm_layer(ngf * mult * 2)]
-             model += [nn.ReLU(True)]
-
-         mult = 2**n_downsampling
-         for i in range(n_blocks):
-             model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
-
-         for i in range(n_downsampling):
-             mult = 2**(n_downsampling - i)
-             # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
-             #                              kernel_size=3, stride=2,
-             #                              padding=1, output_padding=1,
-             #                              bias=use_bias)]
-             # if norm_layer is not None:
-             #     model += [norm_layer(ngf * mult / 2)]
-             # model += [nn.ReLU(True)]
-             model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
-             if norm_layer is not None:
-                 model += [norm_layer(int(ngf * mult / 2))]
-             model += [nn.ReLU(True)]
-             model += [nn.ReplicationPad2d(1),
-                       nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
-             if norm_layer is not None:
-                 model += [norm_layer(ngf * mult / 2)]
-             model += [nn.ReLU(True)]
-         model += [nn.ReplicationPad2d(3)]
-         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
-         #model += [nn.Tanh()]
-
-         self.model = nn.Sequential(*model)
-
-     def forward(self, input):
-         return self.model(input)
-
-
- # Define a resnet block
- class ResnetBlock(nn.Module):
-     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
-         super(ResnetBlock, self).__init__()
-         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
-
-     def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
-         conv_block = []
-         p = 0
-         if padding_type == 'reflect':
-             conv_block += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             conv_block += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
-
-         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
-         if norm_layer is not None:
-             conv_block += [norm_layer(dim)]
-         conv_block += [nn.ReLU(True)]
-         # if use_dropout:
-         #     conv_block += [nn.Dropout(0.5)]
-
-         p = 0
-         if padding_type == 'reflect':
-             conv_block += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             conv_block += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
-         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
-         if norm_layer is not None:
-             conv_block += [norm_layer(dim)]
-
-         return nn.Sequential(*conv_block)
-
-     def forward(self, x):
-         out = x + self.conv_block(x)
-         return out
-
-
- class D_NLayersMulti(nn.Module):
-     def __init__(self, input_nc, ndf=64, n_layers=3,
-                  norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
-         super(D_NLayersMulti, self).__init__()
-         # st()
-         self.num_D = num_D
-         self.nl_layer = nl_layer
-         if num_D == 1:
-             layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
-             self.model = nn.Sequential(*layers)
-         else:
-             layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
-             self.add_module("model_0", nn.Sequential(*layers))
-             self.down = nn.functional.interpolate
-             for i in range(1, num_D):
-                 ndf_i = int(round(ndf / (2**i)))
-                 layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
-                 self.add_module("model_%d" % i, nn.Sequential(*layers))
-
-     def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
-         kw = 3
-         padw = 1
-         sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
-                                             stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]
-
-         nf_mult = 1
-         nf_mult_prev = 1
-         for n in range(1, n_layers):
-             nf_mult_prev = nf_mult
-             nf_mult = min(2**n, 8)
-             sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
-                                                  kernel_size=kw, stride=2, padding=padw))]
-             if norm_layer:
-                 sequence += [norm_layer(ndf * nf_mult)]
-
-             sequence += [self.nl_layer()]
-
-         nf_mult_prev = nf_mult
-         nf_mult = min(2**n_layers, 8)
-         sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
-                                              kernel_size=kw, stride=1, padding=padw))]
-         if norm_layer:
-             sequence += [norm_layer(ndf * nf_mult)]
-         sequence += [self.nl_layer()]
-
-         sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
-                                              kernel_size=kw, stride=1, padding=padw))]
-
-         return sequence
-
-     def forward(self, input):
-         if self.num_D == 1:
-             return self.model(input)
-         result = []
-         down = input
-         for i in range(self.num_D):
-             model = getattr(self, "model_%d" % i)
-             result.append(model(down))
-             if i != self.num_D - 1:
-                 down = self.down(down, scale_factor=0.5, mode='bilinear')
-         return result
-
- class D_NLayers(nn.Module):
-     """Defines a PatchGAN discriminator"""
-
-     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
-         """Construct a PatchGAN discriminator
-         Parameters:
-             input_nc (int)  -- the number of channels in input images
-             ndf (int)       -- the number of filters in the last conv layer
-             n_layers (int)  -- the number of conv layers in the discriminator
-             norm_layer      -- normalization layer
-         """
-         super(D_NLayers, self).__init__()
-         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
-             use_bias = norm_layer.func != nn.BatchNorm2d
-         else:
-             use_bias = norm_layer != nn.BatchNorm2d
-
-         kw = 3
-         padw = 1
-         sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
-         nf_mult = 1
-         nf_mult_prev = 1
-         for n in range(1, n_layers):  # gradually increase the number of filters
-             nf_mult_prev = nf_mult
-             nf_mult = min(2 ** n, 8)
-             sequence += [
-                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
-                 norm_layer(ndf * nf_mult),
-                 nn.LeakyReLU(0.2, True)
-             ]
-
-         nf_mult_prev = nf_mult
-         nf_mult = min(2 ** n_layers, 8)
-         sequence += [
-             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
-             norm_layer(ndf * nf_mult),
-             nn.LeakyReLU(0.2, True)
-         ]
-
-         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
-         self.model = nn.Sequential(*sequence)
-
-     def forward(self, input):
-         """Standard forward."""
-         return self.model(input)
-
-
- class G_Unet_add_input(nn.Module):
-     def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
-                  norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
-                  upsample='basic', device=0):
-         super(G_Unet_add_input, self).__init__()
-         self.nz = nz
-         max_nchn = 8
-         noise = []
-         for i in range(num_downs+1):
-             if use_noise:
-                 noise.append(True)
-             else:
-                 noise.append(False)
-
-         # construct unet structure
-         #print(num_downs)
-         unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
-                                  innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         for i in range(num_downs - 5):
-             unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
-                                      norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
-         unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
-                                  outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-
-         self.model = unet_block
-
-     def forward(self, x, z=None):
-         if self.nz > 0:
-             z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
-                 z.size(0), z.size(1), x.size(2), x.size(3))
-             x_with_z = torch.cat([x, z_img], 1)
-         else:
-             x_with_z = x  # no z
-
-
-         return torch.tanh(self.model(x_with_z))
-         # return self.model(x_with_z)
-
- class G_Unet_add_input_G(nn.Module):
-     def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
-                  norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
-                  upsample='basic', device=0):
-         super(G_Unet_add_input_G, self).__init__()
-         self.nz = nz
-         max_nchn = 8
-         noise = []
-         for i in range(num_downs+1):
-             if use_noise:
-                 noise.append(True)
-             else:
-                 noise.append(False)
-         # construct unet structure
-         #print(num_downs)
-         unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
-                                  innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         for i in range(num_downs - 5):
-             unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
-                                      norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
-         unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
-         unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
-         unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
-                                  norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
-         unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
-                                  outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
-
-         self.model = unet_block
-
-     def forward(self, x, z=None):
-         if self.nz > 0:
-             z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
-                 z.size(0), z.size(1), x.size(2), x.size(3))
-             x_with_z = torch.cat([x, z_img], 1)
-         else:
-             x_with_z = x  # no z
-
-         # return F.tanh(self.model(x_with_z))
-         return self.model(x_with_z)
-
- class G_Unet_add_input_C(nn.Module):
-     def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
-                  norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
-                  upsample='basic', device=0):
-         super(G_Unet_add_input_C, self).__init__()
-         self.nz = nz
-         max_nchn = 8
-         # construct unet structure
-         #print(num_downs)
-         unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
-                                innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         for i in range(num_downs - 5):
-             unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
-                                    norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
-         unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
-                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
-                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
-                                norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
-                                outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-
-         self.model = unet_block
-
-     def forward(self, x, z=None):
-         if self.nz > 0:
-             z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
-                 z.size(0), z.size(1), x.size(2), x.size(3))
-             x_with_z = torch.cat([x, z_img], 1)
-         else:
-             x_with_z = x  # no z
-
-         # return torch.tanh(self.model(x_with_z))
-         return self.model(x_with_z)
-
- def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
-     # padding_type = 'zero'
-     if upsample == 'basic':
-         upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]  #, padding_mode='replicate'
-     elif upsample == 'bilinear' or upsample == 'nearest' or upsample == 'linear':
-         upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
-                   #nn.ReplicationPad2d(1),
-                   nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
-         # p = kw//2
-         # upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
-         #           nn.Conv2d(inplanes, outplanes, kernel_size=kw, stride=1, padding=p, padding_mode='replicate')]
      else:
-         raise NotImplementedError(
-             'upsample layer [%s] not implemented' % upsample)
-     return upconv
-
- class UnetBlock_G(nn.Module):
-     def __init__(self, input_nc, outer_nc, inner_nc,
-                  submodule=None, noise=None, outermost=False, innermost=False,
-                  norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
-         super(UnetBlock_G, self).__init__()
-         self.outermost = outermost
-         p = 0
-         downconv = []
-         if padding_type == 'reflect':
-             downconv += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             downconv += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError(
-                 'padding [%s] is not implemented' % padding_type)
-
-         downconv += [nn.Conv2d(input_nc, inner_nc,
-                                kernel_size=3, stride=2, padding=p)]
-         # downsample is different from upsample
-         downrelu = nn.LeakyReLU(0.2, True)
-         downnorm = norm_layer(inner_nc) if norm_layer is not None else None
-         uprelu = nl_layer()
-         uprelu2 = nl_layer()
-         uppad = nn.ReplicationPad2d(1)
-         upnorm = norm_layer(outer_nc) if norm_layer is not None else None
-         upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
-         self.noiseblock = ApplyNoise(outer_nc)
-         self.noise = noise
-
-         if outermost:
-             upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
-             uppad = nn.ReplicationPad2d(3)
-             upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
-             down = downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [norm_layer(inner_nc)]
-             # upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
-             # upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=0)
-             # down = downconv
-             # up = [uprelu] + upconv
-             # if upnorm is not None:
-             #     up += [norm_layer(outer_nc)]
-             up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
-             model = down + [submodule] + up
-         elif innermost:
-             upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
-             down = [downrelu] + downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-             model = down + up
-         else:
-             upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
-             down = [downrelu] + downconv
-             if downnorm is not None:
-                 down += [downnorm]
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-
-             if use_dropout:
-                 model = down + [submodule] + up + [nn.Dropout(0.5)]
-             else:
-                 model = down + [submodule] + up
-
-         self.model = nn.Sequential(*model)
-
-     def forward(self, x):
-         if self.outermost:
-             return self.model(x)
-         else:
-             x2 = self.model(x)
-             if self.noise:
-                 x2 = self.noiseblock(x2, self.noise)
-             return torch.cat([x2, x], 1)
-
-
- class UnetBlock(nn.Module):
-     def __init__(self, input_nc, outer_nc, inner_nc,
-                  submodule=None, noise=None, outermost=False, innermost=False,
-                  norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
-         super(UnetBlock, self).__init__()
-         self.outermost = outermost
-         p = 0
-         downconv = []
-         if padding_type == 'reflect':
-             downconv += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             downconv += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError(
-                 'padding [%s] is not implemented' % padding_type)
-
-         downconv += [nn.Conv2d(input_nc, inner_nc,
-                                kernel_size=3, stride=2, padding=p)]
-         # downsample is different from upsample
-         downrelu = nn.LeakyReLU(0.2, True)
-         downnorm = norm_layer(inner_nc) if norm_layer is not None else None
-         uprelu = nl_layer()
-         uprelu2 = nl_layer()
-         uppad = nn.ReplicationPad2d(1)
-         upnorm = norm_layer(outer_nc) if norm_layer is not None else None
-         upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
-         self.noiseblock = ApplyNoise(outer_nc)
-         self.noise = noise
-
-         if outermost:
-             upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
-             down = downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
-             model = down + [submodule] + up
-         elif innermost:
-             upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
-             down = [downrelu] + downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-             model = down + up
-         else:
-             upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
-             down = [downrelu] + downconv
-             if downnorm is not None:
-                 down += [downnorm]
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-
-             if use_dropout:
-                 model = down + [submodule] + up + [nn.Dropout(0.5)]
-             else:
-                 model = down + [submodule] + up
-
-         self.model = nn.Sequential(*model)
-
-     def forward(self, x):
-         if self.outermost:
-             return self.model(x)
-         else:
-             x2 = self.model(x)
-             if self.noise:
-                 x2 = self.noiseblock(x2, self.noise)
-             return torch.cat([x2, x], 1)
-
- # Defines the submodule with skip connection.
- # X -------------------identity---------------------- X
- #   |-- downsampling -- |submodule| -- upsampling --|
- class UnetBlock_A(nn.Module):
-     def __init__(self, input_nc, outer_nc, inner_nc,
-                  submodule=None, noise=None, outermost=False, innermost=False,
-                  norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
-         super(UnetBlock_A, self).__init__()
-         self.outermost = outermost
-         p = 0
-         downconv = []
-         if padding_type == 'reflect':
-             downconv += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             downconv += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError(
-                 'padding [%s] is not implemented' % padding_type)
-
-         downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
-                                              kernel_size=3, stride=2, padding=p))]
-         # downsample is different from upsample
-         downrelu = nn.LeakyReLU(0.2, True)
-         downnorm = norm_layer(inner_nc) if norm_layer is not None else None
-         uprelu = nl_layer()
-         uprelu2 = nl_layer()
-         uppad = nn.ReplicationPad2d(1)
-         upnorm = norm_layer(outer_nc) if norm_layer is not None else None
-         upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
-         self.noiseblock = ApplyNoise(outer_nc)
-         self.noise = noise
-
-         if outermost:
-             upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
-             model = down + [submodule] + up
-         elif innermost:
-             upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = [downrelu] + downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-             model = down + up
-         else:
-             upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = [downrelu] + downconv
-             if downnorm is not None:
-                 down += [downnorm]
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-
-             if use_dropout:
-                 model = down + [submodule] + up + [nn.Dropout(0.5)]
-             else:
-                 model = down + [submodule] + up
-
-         self.model = nn.Sequential(*model)
-
-     def forward(self, x):
-         if self.outermost:
-             return self.model(x)
-         else:
-             x2 = self.model(x)
-             if self.noise:
-                 x2 = self.noiseblock(x2, self.noise)
-             if x2.shape[-1] == x.shape[-1]:
-                 return x2 + x
-             else:
-                 x2 = F.interpolate(x2, x.shape[2:])
-                 return x2 + x
-
-
- class E_ResNet(nn.Module):
-     def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
-                  norm_layer=None, nl_layer=None, vaeLike=False):
-         super(E_ResNet, self).__init__()
-         self.vaeLike = vaeLike
-         max_ndf = 4
-         conv_layers = [
-             nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
-         for n in range(1, n_blocks):
-             input_ndf = ndf * min(max_ndf, n)
-             output_ndf = ndf * min(max_ndf, n + 1)
-             conv_layers += [BasicBlock(input_ndf,
-                                        output_ndf, norm_layer, nl_layer)]
-         conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
-         if vaeLike:
-             self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
-             self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
-         else:
-             self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
-         self.conv = nn.Sequential(*conv_layers)
-
-     def forward(self, x):
-         x_conv = self.conv(x)
-         conv_flat = x_conv.view(x.size(0), -1)
-         output = self.fc(conv_flat)
-         if self.vaeLike:
-             outputVar = self.fcVar(conv_flat)
-             return output, outputVar
-         else:
-             return output
-         return output
-
-
- # Defines the Unet generator.
- # |num_downs|: number of downsamplings in UNet. For example,
- # if |num_downs| == 7, image of size 128x128 will become of size 1x1
- # at the bottleneck
- class G_Unet_add_all(nn.Module):
-     def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
-                  norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
-         super(G_Unet_add_all, self).__init__()
-         self.nz = nz
-         self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
-         self.truncation_psi = 0
-         self.truncation_cutoff = 0
-
-         # - 2 means we start from feature map with height and width equals 4.
-         # as this example, we get num_layers = 18.
-         num_layers = int(np.log2(512)) * 2 - 2
-         # Noise inputs.
-         self.noise_inputs = []
-         for layer_idx in range(num_layers):
-             res = layer_idx // 2 + 2
-             shape = [1, 1, 2 ** res, 2 ** res]
-             self.noise_inputs.append(torch.randn(*shape).to("cuda" if torch.cuda.is_available() else "cpu"))
-
-         # construct unet structure
-         unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
-                                       norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
-                                       norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
-         for i in range(num_downs - 6):
-             unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
-                                           norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
-         unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
-                                       norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
-                                       norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
-                                       norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
-                                       outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
-         self.model = unet_block
-
-     def forward(self, x, z):
-
-         dlatents1, num_layers = self.mapping(z)
-         dlatents1 = dlatents1.unsqueeze(1)
-         dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
-
-         # Apply truncation trick.
-         if self.truncation_psi and self.truncation_cutoff:
-             coefs = np.ones([1, num_layers, 1], dtype=np.float32)
-             for i in range(num_layers):
-                 if i < self.truncation_cutoff:
-                     coefs[:, i, :] *= self.truncation_psi
-             """Linear interpolation.
-             a + (b - a) * t (a = 0)
-             reduce to
-             b * t
-             """
-             dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
-
-         return torch.tanh(self.model(x, dlatents1, self.noise_inputs))
-
-
- class ApplyNoise(nn.Module):
-     def __init__(self, channels):
-         super().__init__()
-         self.channels = channels
-         self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
-         self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
-
-     def forward(self, x, noise):
-         W, _ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
-         B, _ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
-         Z = torch.zeros_like(W)
-         w = torch.cat([W, Z], dim=1).to(x.device)
-         b = torch.cat([B, Z], dim=1).to(x.device)
-         adds = w * torch.randn_like(x) + b
-         return x + adds.type_as(x)
-
-
- class FC(nn.Module):
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  gain=2**(0.5),
-                  use_wscale=False,
-                  lrmul=1.0,
-                  bias=True):
-         """
-         The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
-         """
-         super(FC, self).__init__()
-         he_std = gain * in_channels ** (-0.5)  # He init
-         if use_wscale:
-             init_std = 1.0 / lrmul
-             self.w_lrmul = he_std * lrmul
-         else:
-             init_std = he_std / lrmul
-             self.w_lrmul = lrmul
-
-         self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
-         if bias:
-             self.bias = torch.nn.Parameter(torch.zeros(out_channels))
-             self.b_lrmul = lrmul
-         else:
-             self.bias = None
-
-     def forward(self, x):
-         if self.bias is not None:
-             out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
-         else:
-             out = F.linear(x, self.weight * self.w_lrmul)
-         out = F.leaky_relu(out, 0.2, inplace=True)
-         return out
-
-
- class ApplyStyle(nn.Module):
      """
-     @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
      """
-     def __init__(self, latent_size, channels, use_wscale, nl_layer):
-         super(ApplyStyle, self).__init__()
-         modules = [nn.Linear(latent_size, channels*2)]
-         if nl_layer:
-             modules += [nl_layer()]
-         self.linear = nn.Sequential(*modules)
-
-     def forward(self, x, latent):
-         style = self.linear(latent)  # style => [batch_size, n_channels*2]
-         shape = [-1, 2, x.size(1), 1, 1]
-         style = style.view(shape)  # [batch_size, 2, n_channels, ...]
-         x = x * (style[:, 0] + 1.) + style[:, 1]
-         return x
-
- class PixelNorm(nn.Module):
-     def __init__(self, epsilon=1e-8):
-         """
-         @notice: avoid in-place ops.
-         https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
-         """
-         super(PixelNorm, self).__init__()
-         self.epsilon = epsilon
-
-     def forward(self, x):
-         tmp = torch.mul(x, x)  # or x ** 2
-         tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
-
-         return x * tmp1
-
-
- class InstanceNorm(nn.Module):
-     def __init__(self, epsilon=1e-8):
-         """
-         @notice: avoid in-place ops.
-         https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
-         """
-         super(InstanceNorm, self).__init__()
-         self.epsilon = epsilon
-
-     def forward(self, x):
-         x = x - torch.mean(x, (2, 3), True)
-         tmp = torch.mul(x, x)  # or x ** 2
-         tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
-         return x * tmp
-
-
- class LayerEpilogue(nn.Module):
-     def __init__(self, channels, dlatent_size, use_wscale, use_noise,
-                  use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
-         super(LayerEpilogue, self).__init__()
-         self.use_noise = use_noise
-         if use_noise:
-             self.noise = ApplyNoise(channels)
-         self.act = nn.LeakyReLU(negative_slope=0.2)
-
-         if use_pixel_norm:
-             self.pixel_norm = PixelNorm()
-         else:
-             self.pixel_norm = None
-
-         if use_instance_norm:
-             self.instance_norm = InstanceNorm()
-         else:
-             self.instance_norm = None
-
-         if use_styles:
-             self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
-         else:
-             self.style_mod = None
-
-     def forward(self, x, noise, dlatents_in_slice=None):
-         # if noise is not None:
-         if self.use_noise:
-             x = self.noise(x, noise)
-         x = self.act(x)
-         if self.pixel_norm is not None:
-             x = self.pixel_norm(x)
-         if self.instance_norm is not None:
-             x = self.instance_norm(x)
-         if self.style_mod is not None:
-             x = self.style_mod(x, dlatents_in_slice)
-
-         return x
-
- class G_mapping(nn.Module):
-     def __init__(self,
-                  mapping_fmaps=512,
-                  dlatent_size=512,
-                  resolution=512,
-                  normalize_latents=True,  # Normalize latent vectors (Z) before feeding them to the mapping layers?
-                  use_wscale=True,         # Enable equalized learning rate?
-                  lrmul=0.01,              # Learning rate multiplier for the mapping layers.
-                  gain=2**(0.5),           # original gain in tensorflow.
-                  nl_layer=None
-                  ):
-         super(G_mapping, self).__init__()
-         self.mapping_fmaps = mapping_fmaps
-         func = [
-             nn.Linear(self.mapping_fmaps, dlatent_size)
-         ]
-         if nl_layer:
-             func += [nl_layer()]
-
-         for j in range(0,4):
-             func += [
-                 nn.Linear(dlatent_size, dlatent_size)
-             ]
-             if nl_layer:
-                 func += [nl_layer()]
-
-         self.func = nn.Sequential(*func)
-         #FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
-         #FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
-
-         self.normalize_latents = normalize_latents
-         self.resolution_log2 = int(np.log2(resolution))
-         self.num_layers = self.resolution_log2 * 2 - 2
-         self.pixel_norm = PixelNorm()
-         # - 2 means we start from feature map with height and width equals 4.
-         # as this example, we get num_layers = 18.
-
-     def forward(self, x):
-         if self.normalize_latents:
-             x = self.pixel_norm(x)
-         out = self.func(x)
-         return out, self.num_layers
-
- class UnetBlock_with_z(nn.Module):
-     def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
-                  submodule=None, outermost=False, innermost=False,
-                  norm_layer=None, nl_layer=None, use_dropout=False,
-                  upsample='basic', padding_type='replicate'):
-         super(UnetBlock_with_z, self).__init__()
-         p = 0
-         downconv = []
-         if padding_type == 'reflect':
-             downconv += [nn.ReflectionPad2d(1)]
-         elif padding_type == 'replicate':
-             downconv += [nn.ReplicationPad2d(1)]
-         elif padding_type == 'zero':
-             p = 1
-         else:
-             raise NotImplementedError(
-                 'padding [%s] is not implemented' % padding_type)
-
-         self.outermost = outermost
-         self.innermost = innermost
-         self.nz = nz
-
-         # input_nc = input_nc + nz
-         downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
-                                              kernel_size=3, stride=2, padding=p))]
-         # downsample is different from upsample
-         downrelu = nn.LeakyReLU(0.2, True)
-         downnorm = norm_layer(inner_nc) if norm_layer is not None else None
-         uprelu = nl_layer()
-         uprelu2 = nl_layer()
-         uppad = nn.ReplicationPad2d(1)
-         upnorm = norm_layer(outer_nc) if norm_layer is not None else None
-         upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
-
-         use_styles = False
-         uprelu = nl_layer()
-         if self.nz > 0:
-             use_styles = True
-
-         if outermost:
-             self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
-                                        use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
-             upconv = upsampleLayer(
-                 inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = downconv
-             up = [uprelu] + upconv
-             if upnorm is not None:
-                 up += [upnorm]
-             up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
-         elif innermost:
-             self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
-                                        use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
-             upconv = upsampleLayer(
-                 inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = [downrelu] + downconv
-             up = [uprelu] + upconv
-             if norm_layer is not None:
-                 up += [norm_layer(outer_nc)]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-         else:
-             self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
-                                        use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
-             upconv = upsampleLayer(
-                 inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
-             upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
-             down = [downrelu] + downconv
-             if norm_layer is not None:
-                 down += [norm_layer(inner_nc)]
-             up = [uprelu] + upconv
-
-             if norm_layer is not None:
-                 up += [norm_layer(outer_nc)]
-             up += [uprelu2, uppad, upconv2]
-             if upnorm2 is not None:
-                 up += [upnorm2]
-
-         if use_dropout:
-             up += [nn.Dropout(0.5)]
-         self.down = nn.Sequential(*down)
-         self.submodule = submodule
-         self.up = nn.Sequential(*up)
-
-
-     def forward(self, x, z, noise):
-         if self.outermost:
-             x1 = self.down(x)
-             x2 = self.submodule(x1, z[:,2:], noise[2:])
-             return self.up(x2)
-
-         elif self.innermost:
-             x1 = self.down(x)
-             x_and_z = self.adaIn(x1, noise[0], z[:,0])
-             x2 = self.up(x_and_z)
-             x2 = F.interpolate(x2, x.shape[2:])
-             return x2 + x
-
-         else:
-             x1 = self.down(x)
-             x2 = self.submodule(x1, z[:,2:], noise[2:])
-             x_and_z = self.adaIn(x2, noise[0], z[:,0])
-             return self.up(x_and_z) + x
-
-
- class E_NLayers(nn.Module):
-     def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
-                  norm_layer=None, nl_layer=None, vaeLike=False):
-         super(E_NLayers, self).__init__()
-         self.vaeLike = vaeLike
-
-         kw, padw = 3, 1
-         sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
-                                             stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
-
-         nf_mult = 1
-         nf_mult_prev = 1
-         for n in range(1, n_layers):
-             nf_mult_prev = nf_mult
-             nf_mult = min(2**n, 8)
-             sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
-                                                  kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
-             if norm_layer is not None:
-                 sequence += [norm_layer(ndf * nf_mult)]
-             sequence += [nl_layer()]
-         sequence += [nn.AdaptiveAvgPool2d(4)]
-         self.conv = nn.Sequential(*sequence)
-         self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
-         if vaeLike:
-             self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
-
-     def forward(self, x):
-         x_conv = self.conv(x)
-         conv_flat = x_conv.view(x.size(0), -1)
-         output = self.fc(conv_flat)
-         if self.vaeLike:
-             outputVar = self.fcVar(conv_flat)
-             return output, outputVar
-         return output
-
- class BasicBlock(nn.Module):
-     def __init__(self, inplanes, outplanes):
-         super(BasicBlock, self).__init__()
-         layers = []
-         norm_layer = get_norm_layer(norm_type='layer')  #functools.partial(LayerNorm)
-         # norm_layer = None
-         nl_layer = nn.ReLU()
-         if norm_layer is not None:
-             layers += [norm_layer(inplanes)]
-         layers += [nl_layer]
-         layers += [nn.ReplicationPad2d(1),
-                    nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
-                              padding=0, bias=True)]
-         self.conv = nn.Sequential(*layers)
-
-     def forward(self, x):
-         return self.conv(x)
-
-
- def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
-                 init_type="normal", init_gain=0.02, gpu_ids=[]):
-     if netVAE == 'SVAE':
-         net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
-                         init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
-     else:
-         raise NotImplementedError('Encoder model name [%s] is not recognized' % net)
-     init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
-     net.load_networks('latest')
-     return net
-
-
- class ScreenVAE(nn.Module):
-     def __init__(self, inc=1, outc=4, outplanes=64, downs=5, blocks=2, load_ext=True, save_dir='', init_type="normal", init_gain=0.02, gpu_ids=[]):
-         super(ScreenVAE, self).__init__()
-         self.inc = inc
-         self.outc = outc
-         self.save_dir = save_dir
-         norm_layer = functools.partial(LayerNormWarpper)
-         nl_layer = nn.LeakyReLU
-
-         self.model_names = ['enc', 'dec']
-         self.enc = define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks',
-                             norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
-                             gpu_ids=gpu_ids, upsample='bilinear')
-         self.dec = define_G(outc, inc, 0, 48, netG='unet_128_G',
-                             norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
-                             gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)
-
-         for param in self.parameters():
-             param.requires_grad = False
-
-     def load_networks(self, epoch):
-         """Load all the networks from the disk.
-
-         Parameters:
-             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
-         """
-         for name in self.model_names:
-             if isinstance(name, str):
-                 load_filename = '%s_net_%s.pth' % (epoch, name)
-                 load_path = os.path.join(self.save_dir, load_filename)
-                 net = getattr(self, name)
-                 if isinstance(net, torch.nn.DataParallel):
-                     net = net.module
-                 print('loading the model from %s' % load_path)
-                 state_dict = torch.load(
-                     load_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-                 if hasattr(state_dict, '_metadata'):
-                     del state_dict._metadata
-
-                 net.load_state_dict(state_dict)
-                 del state_dict
-
-     def npad(self, im, pad=128):
-         h, w = im.shape[-2:]
-         hp = h // pad * pad + pad
-         wp = w // pad * pad + pad
-         return F.pad(im, (0, wp-w, 0, hp-h), mode='replicate')
-
-     def forward(self, x, line=None, img_input=True, output_screen_only=True):
-         if img_input:
-             if line is None:
-                 line = torch.ones_like(x)
-             else:
-                 line = torch.sign(line)
-             x = torch.clamp(x + (1-line), -1, 1)
-             h, w = x.shape[-2:]
-             input = torch.cat([x, line], 1)
-             input = self.npad(input)
-             inter = self.enc(input)[:, :, :h, :w]
-             scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
-             if output_screen_only:
-                 return scr
-             recons = self.dec(scr)
-             return recons, scr, logvar
-         else:
-             h, w = x.shape[-2:]
-             x = self.npad(x)
-             recons = self.dec(x)[:, :, :h, :w]
-             recons = (recons+1)*(line+1)/2-1
-             return torch.clamp(recons, -1, 1)
1
+ import spaces
+ import contextlib
+ import gc
+ import json
+ import logging
+ import math
+ import os
+ import random
+ import shutil
+ import sys
+ import time
+ import itertools
+ from pathlib import Path
+
+ import cv2
  import numpy as np
+ from PIL import Image, ImageDraw
+ import torch
  import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+
+ import accelerate
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import ProjectConfiguration, set_seed
+
+ from datasets import load_dataset
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from safetensors.torch import load_model
+ from peft import LoraConfig
+ import gradio as gr
+ import pandas as pd
+
+ import transformers
+ from transformers import (
+     AutoTokenizer,
+     PretrainedConfig,
+     CLIPVisionModelWithProjection,
+     CLIPImageProcessor,
+     CLIPProcessor,
+ )
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+     ColorGuiderPixArtModel,
+     ColorGuiderSDModel,
+     UNet2DConditionModel,
+     PixArtTransformer2DModel,
+     ColorFlowPixArtAlphaPipeline,
+     ColorFlowSDPipeline,
+     UniPCMultistepScheduler,
+ )
+ from colorflow_utils.utils import *
+
+ sys.path.append('./BidirectionalTranslation')
+ from options.test_options import TestOptions
+ from models import create_model
+ from util import util
+
+ from huggingface_hub import snapshot_download
+
+
+ article = r"""
+ If ColorFlow is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/ColorFlow' target='_blank'>GitHub Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/ColorFlow)](https://github.com/TencentARC/ColorFlow)
+ ---
+
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to reach out to me at <b>zhuangjh23@mails.tsinghua.edu.cn</b>.
+
+ 📝 **Citation**
+ <br>
+ If our work is useful for your research, please consider citing:
+ ```bibtex
+ @misc{zhuang2024colorflow,
+       title={ColorFlow: Retrieval-Augmented Image Sequence Colorization},
+       author={Junhao Zhuang and Xuan Ju and Zhaoyang Zhang and Yong Liu and Shiyi Zhang and Chun Yuan and Ying Shan},
+       year={2024},
+       eprint={2412.11815},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2412.11815},
+ }
+ ```
+ """
+
+ model_global_path = snapshot_download(repo_id="TencentARC/ColorFlow", cache_dir='./colorflow/', repo_type="model")
+ print(model_global_path)
+
+
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ ])
+ weight_dtype = torch.float16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # line model
+ line_model_path = model_global_path + '/LE/erika.pth'
+ line_model = res_skip()
+ line_model.load_state_dict(torch.load(line_model_path))
+ line_model.eval()
+ line_model.to(device)
+
+ # screen model
+ opt = TestOptions().parse(model_global_path)
+ ScreenModel = create_model(opt, model_global_path)
+ ScreenModel.setup(opt)
+ ScreenModel.eval()
+
+ image_processor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to(device)
+
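For reference, this encoder pair is what later turns image patches into embeddings for retrieval: the processor resizes and normalizes PIL images into pixel tensors, and the vision model projects them into CLIP embedding space. A minimal sketch of that flow; the checkpoint name here is an assumption for illustration, the app loads its own weights from `model_global_path`:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

processor = CLIPImageProcessor()
# Assumed public checkpoint; the demo uses its own '/image_encoder/' weights.
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

patches = [Image.new("RGB", (224, 224), "white")]  # placeholder patches
pixels = processor(images=patches, return_tensors="pt").pixel_values
with torch.no_grad():
    embeds = encoder(pixels).image_embeds           # (N, projection_dim)
```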
+
+ examples = [
+     [
+         "./assets/example_5/input.png",
+         ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_4/input.jpg",
+         ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
+         "GrayImage(ScreenStyle)",
+         "640x640",
+         0,
+         10
+     ],
+     [
+         "./assets/example_3/input.png",
+         ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_2/input.png",
+         ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_6/input.png",
+         ["./assets/example_6/ref1.png", "./assets/example_6/ref2.png", "./assets/example_6/ref3.png"],
+         "Sketch_Shading",
+         "512x800",
+         0,
+         10
+     ],
+     [
+         "./assets/example_7/input.jpg",
+         ["./assets/example_7/ref1.jpg", "./assets/example_7/ref2.jpg", "./assets/example_7/ref3.jpg", "./assets/example_7/ref4.jpg"],
+         "Sketch_Shading",
+         "640x640",
+         2,
+         10
+     ],
+     [
+         "./assets/example_1/input.jpg",
+         ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
+         "Sketch",
+         "640x640",
+         1,
+         10
+     ],
+     [
+         "./assets/example_0/input.jpg",
+         ["./assets/example_0/ref1.jpg"],
+         "Sketch",
+         "640x640",
+         1,
+         10
+     ],
+ ]
+
+ global pipeline
+ global MultiResNetModel
+
+ @spaces.GPU
+ def load_ckpt(input_style):
+     global pipeline
+     global MultiResNetModel
+     if input_style == "Sketch" or input_style == "Sketch_Shading":
+         if input_style == "Sketch":
+             ckpt_path = model_global_path + '/sketch/'
+             rank = 128
          else:
+             ckpt_path = model_global_path + '/shading/'
+             rank = 128
+         pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
+         transformer = PixArtTransformer2DModel.from_pretrained(
+             pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
+         )
+         pixart_config = get_pixart_config()
+
+         ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)
+
+         transformer_lora_config = LoraConfig(
+             r=rank,
+             lora_alpha=rank,
+             init_lora_weights="gaussian",
+             target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"]
+         )
+         transformer.add_adapter(transformer_lora_config)
+         ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
+         transformer.load_state_dict(ckpt_key_t, strict=False)
+
+         transformer.to(device, dtype=weight_dtype)
+         ColorGuider.to(device, dtype=weight_dtype)
+
+         pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
+             pretrained_model_name_or_path,
+             transformer=transformer,
+             colorguider=ColorGuider,
+             safety_checker=None,
+             revision=None,
+             variant=None,
+             torch_dtype=weight_dtype,
+         )
+         pipeline = pipeline.to(device)
+         block_out_channels = [128, 128, 256, 512, 512]
+
+         MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+         MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+         MultiResNetModel.to(device, dtype=weight_dtype)
+
+     elif input_style == "GrayImage(ScreenStyle)":
+         ckpt_path = model_global_path + '/GraySD/'
+         rank = 64
+         pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
+         unet = UNet2DConditionModel.from_pretrained(
+             pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
+         )
+         ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
+         ColorGuider.to(device, dtype=weight_dtype)
+         unet.to(device, dtype=weight_dtype)
+
+         pipeline = ColorFlowSDPipeline.from_pretrained(
+             pretrained_model_name_or_path,
+             unet=unet,
+             colorguider=ColorGuider,
+             safety_checker=None,
+             revision=None,
+             variant=None,
+             torch_dtype=weight_dtype,
+         )
+         pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+         unet_lora_config = LoraConfig(
+             r=rank,
+             lora_alpha=rank,
+             init_lora_weights="gaussian",
+             target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],
+         )
+         pipeline.unet.add_adapter(unet_lora_config)
+         pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
+         pipeline = pipeline.to(device)
+         block_out_channels = [128, 128, 256, 512, 512]
+
+         MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+         MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+         MultiResNetModel.to(device, dtype=weight_dtype)
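Every branch of `load_ckpt` follows the same recipe: wrap the frozen base network with a PEFT `LoraConfig`, then load only the LoRA tensors with `strict=False` so the partial checkpoint does not trip the full-model key check. A condensed sketch of that pattern; the rank, target modules, and `unet_lora.bin` path are placeholders taken from the GraySD branch above:

```python
import torch
from peft import LoraConfig
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
lora_config = LoraConfig(
    r=64, lora_alpha=64, init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(lora_config)  # injects LoRA layers into the attention blocks

# strict=False: the checkpoint holds only the LoRA weights, not the full model.
state = torch.load("unet_lora.bin", map_location="cpu")  # placeholder path
unet.load_state_dict(state, strict=False)
```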
+
+
+ global cur_input_style
+ cur_input_style = "Sketch"
+ load_ckpt(cur_input_style)
+ cur_input_style = "Sketch_Shading"
+ load_ckpt(cur_input_style)
+ cur_input_style = "GrayImage(ScreenStyle)"
+ load_ckpt(cur_input_style)
+ cur_input_style = None
+
+ @spaces.GPU
+ def fix_random_seeds(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+
+ def process_multi_images(files):
+     return [Image.open(file.name) for file in files]
+
+ @spaces.GPU
+ def extract_lines(image):
+     src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+
+     rows = int(np.ceil(src.shape[0] / 16)) * 16
+     cols = int(np.ceil(src.shape[1] / 16)) * 16
+
+     patch = np.ones((1, 1, rows, cols), dtype="float32")
+     patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
+
+     tensor = torch.from_numpy(patch).to(device)
+
+     with torch.no_grad():
+         y = line_model(tensor)
+
+     yc = y.cpu().numpy()[0, 0, :, :]
+     yc[yc > 255] = 255
+     yc[yc < 0] = 0
+
+     outimg = yc[0:src.shape[0], 0:src.shape[1]]
+     outimg = outimg.astype(np.uint8)
+     outimg = Image.fromarray(outimg)
+     torch.cuda.empty_cache()
+     return outimg
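`extract_lines` pads the grayscale page up to multiples of 16 (white fill) so the line-extraction network's downsampling divides evenly, then crops and clamps the output; the same pad-run-crop pattern in isolation, with `model` as a stand-in for any fully-convolutional network:

```python
import numpy as np
import torch

def run_padded(model, src: np.ndarray) -> np.ndarray:
    # src: HxW uint8 grayscale page; pad to multiples of 16 with ones.
    rows = int(np.ceil(src.shape[0] / 16)) * 16
    cols = int(np.ceil(src.shape[1] / 16)) * 16
    patch = np.ones((1, 1, rows, cols), dtype="float32")
    patch[0, 0, :src.shape[0], :src.shape[1]] = src
    with torch.no_grad():
        y = model(torch.from_numpy(patch))
    out = y.numpy()[0, 0, :src.shape[0], :src.shape[1]]  # crop back
    return np.clip(out, 0, 255).astype(np.uint8)         # clamp to valid range
```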
+
+ @spaces.GPU
+ def to_screen_image(input_image):
+     global opt
+     global ScreenModel
+     input_image = input_image.convert('RGB')
+     input_image = get_ScreenVAE_input(input_image, opt)
+     h = input_image['h']
+     w = input_image['w']
+     ScreenModel.set_input(input_image)
+     fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
+     images = fake_B2[:, :, :h, :w]
+     im = util.tensor2im(images)
+     image_pil = Image.fromarray(im)
+     torch.cuda.empty_cache()
+     return image_pil
+
+ @spaces.GPU
+ def extract_line_image(query_image_, input_style, resolution):
+     if resolution == "640x640":
+         tar_width = 640
+         tar_height = 640
+     elif resolution == "512x800":
+         tar_width = 512
+         tar_height = 800
+     elif resolution == "800x512":
+         tar_width = 800
+         tar_height = 512
      else:
+         gr.Info("Unsupported resolution")
+         raise ValueError("Unsupported resolution")
+
+     query_image = process_image(query_image_, int(tar_width*1.5), int(tar_height*1.5))
+     if input_style == "GrayImage(ScreenStyle)":
+         extracted_line = to_screen_image(query_image)
+         extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
+         input_context = extracted_line
+     elif input_style == "Sketch":
+         query_image = query_image.convert('L').convert('RGB')
+         extracted_line = extract_lines(query_image)
+         extracted_line = extracted_line.convert('L').convert('RGB')
+         input_context = extracted_line
+     elif input_style == "Sketch_Shading":
+         query_image = query_image.convert('L').convert('RGB')
+         extracted_line = extract_lines(query_image)
+         extracted_line = extracted_line.convert('L').convert('RGB')
+         array1 = np.array(query_image)
+         array2 = np.array(extracted_line)
+         array2[array1 < 0.3 * 255.0] = 0
+         gray_rate = 125
+         up_bound = 145
+         array2[(array2 > gray_rate) & (array1 < up_bound) & (array1 > 0.3 * 255.0)] = gray_rate
+         input_context = Image.fromarray(np.uint8(array2))
+     torch.cuda.empty_cache()
+     return input_context, extracted_line, input_context
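In the `Sketch_Shading` branch above, dark source pixels force the line art to solid black and mid-tone pixels collapse to one flat gray level, which is what gives the model a shading hint. The same masking logic isolated as a function, with the thresholds copied from the code:

```python
import numpy as np

def add_shading(gray: np.ndarray, lines: np.ndarray,
                gray_rate: int = 125, up_bound: int = 145) -> np.ndarray:
    # gray, lines: HxWx3 uint8 arrays from the grayscale page and its line art.
    out = lines.copy()
    out[gray < 0.3 * 255.0] = 0                                   # dark -> black
    mid = (out > gray_rate) & (gray < up_bound) & (gray > 0.3 * 255.0)
    out[mid] = gray_rate                                          # mid-tones -> flat gray
    return out
```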
+
+ @spaces.GPU(duration=180)
+ def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
+     if VAE_input is None or input_context is None:
+         gr.Info("Please preprocess the image first")
+         raise ValueError("Please preprocess the image first")
+     global cur_input_style
+     global pipeline
+     global MultiResNetModel
+     if input_style != cur_input_style:
+         gr.Info(f"Loading {input_style} model...")
+         load_ckpt(input_style)
+         cur_input_style = input_style
+         gr.Info(f"{input_style} model loaded")
+     reference_images = process_multi_images(reference_images)
+     fix_random_seeds(seed)
+     if resolution == "640x640":
+         tar_width = 640
+         tar_height = 640
+     elif resolution == "512x800":
+         tar_width = 512
+         tar_height = 800
+     elif resolution == "800x512":
+         tar_width = 800
+         tar_height = 512
+     else:
+         gr.Info("Unsupported resolution")
+         raise ValueError("Unsupported resolution")
+     validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width*2, tar_height*2))
+     gr.Info("Image retrieval in progress...")
+     query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
+     query_image = query_image_bw.convert('RGB')
+     query_image_vae = process_image(VAE_input, int(tar_width*1.5), int(tar_height*1.5))
+     reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
+     query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
+     reference_patches_pil = []
+     for reference_image in reference_images:
+         reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
+     combined_image = None
+     with torch.no_grad():
+         clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+         query_embeddings = image_encoder(clip_img).image_embeds
+         # NOTE: the original called .convert('RGB') twice here; the variable name
+         # suggests .convert('L').convert('RGB') (grayscale) may have been intended.
+         reference_patches_pil_gray = [rimg.convert('RGB') for rimg in reference_patches_pil]
+         clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+         reference_embeddings = image_encoder(clip_img).image_embeds
+         cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
+         sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
+         top_k = 3
+         top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
+         combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
+         combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width//2, tar_height//2))
+         idx_table = {0: [(1, 0), (0, 1), (0, 0)], 1: [(1, 3), (0, 2), (0, 3)], 2: [(2, 0), (3, 1), (3, 0)], 3: [(2, 3), (3, 2), (3, 3)]}
+         for i in range(2):
+             for j in range(2):
+                 idx_list = idx_table[i * 2 + j]
+                 for k in range(top_k):
+                     ref_index = top_k_indices[i * 2 + j][k]
+                     idx_y = idx_list[k][0]
+                     idx_x = idx_list[k][1]
+                     combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
+     gr.Info("Model inference in progress...")
+     generator = torch.Generator(device=device).manual_seed(seed)
+     image = pipeline(
+         "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
+     ).images[0]
+     gr.Info("Post-processing image...")
+     with torch.no_grad():
+         width, height = image.size
+         new_width = width // 2
+         new_height = height // 2
+         left = (width - new_width) // 2
+         top = (height - new_height) // 2
+         right = left + new_width
+         bottom = top + new_height
+         center_crop = image.crop((left, top, right, bottom))
+         up_img = center_crop.resize(query_image_vae.size)
+         test_low_color = transform(up_img).unsqueeze(0).to(device, dtype=weight_dtype)
+         query_image_vae = transform(query_image_vae).unsqueeze(0).to(device, dtype=weight_dtype)
+
+         h_color, hidden_list_color = pipeline.vae._encode(test_low_color, return_dict=False, hidden_flag=True)
+         h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict=False, hidden_flag=True)
+
+         hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim=1) for hidden_idx in range(len(hidden_list_color))]
+
+         hidden_list = MultiResNetModel(hidden_list_double)
+         output = pipeline.vae._decode(h_color.sample(), return_dict=False, hidden_list=hidden_list)[0]
+
+         output[output > 1] = 1
+         output[output < -1] = -1
+         high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
+     gr.Info("Colorization complete!")
+     torch.cuda.empty_cache()
+     return high_res_image, up_img, image, query_image_bw
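The retrieval step above scores every query quadrant against every reference patch with cosine similarity over CLIP embeddings and keeps the top 3 matches per quadrant; those matches are then pasted around the query in the 2x-size composite the pipeline conditions on. The scoring in isolation, with placeholder embedding tensors:

```python
import torch
import torch.nn.functional as F

query_embeds = torch.randn(4, 768)   # one embedding per query quadrant (placeholder)
ref_embeds = torch.randn(12, 768)    # embeddings of all reference patches (placeholder)

# Broadcast to a (4, 12) similarity matrix: each query against each reference.
sims = F.cosine_similarity(query_embeds.unsqueeze(1), ref_embeds.unsqueeze(0), dim=-1)
top_k_indices = torch.argsort(sims, descending=True, dim=1)[:, :3].tolist()
```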
+
+ with gr.Blocks() as demo:
+     gr.HTML(
          """
+         <div style="text-align: center;">
+             <h1 style="text-align: center; font-size: 3em;">🎨 ColorFlow:</h1>
+             <h3 style="text-align: center; font-size: 1.8em;">Retrieval-Augmented Image Sequence Colorization</h3>
+             <p style="text-align: center; font-weight: bold;">
+                 <a href="https://zhuang2002.github.io/ColorFlow/">Project Page</a> |
+                 <a href="https://arxiv.org/abs/2412.11815">ArXiv Preprint</a> |
+                 <a href="https://github.com/TencentARC/ColorFlow">GitHub Repository</a>
+             </p>
+             <p style="text-align: center; font-weight: bold;">
+                 NOTE: Each time you switch the input style, the corresponding model will be reloaded, which may take some time. Please be patient.
+             </p>
+             <p style="text-align: left; font-size: 1.1em;">
+                 Welcome to the demo of <strong>ColorFlow</strong>. Follow the steps below to explore the capabilities of our model:
+             </p>
+         </div>
+         <div style="text-align: left; margin: 0 auto;">
+             <ol style="font-size: 1.1em;">
+                 <li>Choose an input style: GrayImage(ScreenStyle), Sketch_Shading, or Sketch.</li>
+                 <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
+                 <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
+                 <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
+                 <li>Set sampling parameters (optional): Adjust the settings and click the <b>Colorize</b> button.</li>
+             </ol>
+             <p>
+                 ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU has an inference time limit of 180 seconds. You may need to log in with a free account to use this demo. A large number of sampling steps can lead to a timeout (GPU abort); in that case, please consider logging in with a Pro account or running the demo on your local machine.
+             </p>
+         </div>
          """
+     )
+     VAE_input = gr.State()
+     input_context = gr.State()
+
+     with gr.Column():
+         with gr.Row():
+             input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch_Shading", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(type="pil", label="Image to Colorize")
+                 resolution = gr.Radio(["640x640", "512x800", "800x512"], label="Select Resolution (Width x Height)", value="640x640")
+                 extract_button = gr.Button("Preprocess (Decolorize)")
+                 extracted_image = gr.Image(type="pil", label="Decolorized Result")
+                 with gr.Row():
+                     reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
+             with gr.Column():
+                 output_gallery = gr.Gallery(label="Colorization Results", type="pil")
+                 seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
+                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=100, value=10, step=1)
+                 colorize_button = gr.Button("Colorize")
+
+     extract_button.click(
+         extract_line_image,
+         inputs=[input_image, input_style, resolution],
+         outputs=[extracted_image, VAE_input, input_context]
+     )
+     colorize_button.click(
+         colorize_image,
+         inputs=[VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps],
+         outputs=output_gallery
+     )
+
+     with gr.Column():
+         gr.Markdown("### Quick Examples")
+         gr.Examples(
+             examples=examples,
+             inputs=[input_image, reference_images, input_style, resolution, seed, num_inference_steps],
+             label="Examples",
+             examples_per_page=8,
+         )
+     gr.HTML('<a href="https://github.com/TencentARC/ColorFlow"><img src="https://img.shields.io/github/stars/TencentARC/ColorFlow" alt="GitHub Stars"></a>')
+     gr.Markdown(article)
+
+
+ demo.launch()