Thompson001 committed on
Commit
8290d17
·
verified ·
1 Parent(s): 298eabc

Upload 5 files

Browse files
models/base_model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
class BaseModel(ABC):
    """Abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list):   training losses to plot and save; each name X is read
               from attribute ``self.loss_X`` (see get_current_losses).
            -- self.model_names (str list):  networks used in training; each name X is read from
               attribute ``self.netX`` (see save_networks / load_networks / eval).
            -- self.visual_names (str list): images to display and save; each name is read from an
               attribute of the same name (see get_current_visuals).
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one
               optimizer for each network. If two networks are updated at the same time, you can use
               itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # get device name: CPU or the first listed GPU
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        if hasattr(opt, 'checkpoints_dir'):
            self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
        if not hasattr(opt, 'preprocess') or opt.preprocess != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add
                               training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration."""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        # at test time, or when resuming training, restore weights from disk
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization."""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data."""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch."""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)  # 'plateau' needs the tracked metric to decide when to decay
            else:
                scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # save the raw (unwrapped) module's weights on CPU, then move the net back to GPU
                    # NOTE(review): assumes net is wrapped in DataParallel (as done by init_net) — confirm
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to PyTorch 0.4).

        Recursively walks ``module`` following the dotted checkpoint key in ``keys`` and removes
        running-stat entries that current InstanceNorm layers no longer store.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module  # load into the wrapped module, not the DataParallel wrapper
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict, strict=False)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture.

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations.

        Parameters:
            nets (network list)   -- a list of networks (a single network is also accepted)
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
models/deepcrack_networks.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! -*- coding: utf-8 -*-
2
+ # Author: Yahui Liu <yahui.liu@unitn.it>
3
+
4
+ """
5
+ Reference:
6
+
7
+ DeepCrack: A deep hierarchical feature learning architecture for crack segmentation.
8
+ https://www.sciencedirect.com/science/article/pii/S0925231219300566
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from .networks import get_norm_layer, init_net
15
+
16
class DeepCrackNet(nn.Module):
    """DeepCrack segmentation network.

    A VGG-style five-stage encoder where every stage emits a 1x1-conv "side"
    prediction; the side maps are upsampled to the input resolution and fused
    by one more 1x1 conv.  Returns the five side outputs plus the fused map.
    """

    #: (in_channels_factory, out_channels, conv count) per encoder stage;
    #: channel widths are multiples of ngf, the final stage keeps ngf*8.
    def __init__(self, in_nc, num_classes, ngf, norm='batch'):
        super(DeepCrackNet, self).__init__()

        norm_layer = get_norm_layer(norm_type=norm)

        # (input channels, output channels, number of conv layers) per stage
        stages = [
            (in_nc,   ngf,     2),
            (ngf,     ngf * 2, 2),
            (ngf * 2, ngf * 4, 3),
            (ngf * 4, ngf * 8, 3),
            (ngf * 8, ngf * 8, 3),
        ]
        for stage_idx, (cin, cout, depth) in enumerate(stages, start=1):
            setattr(self, 'conv%d' % stage_idx,
                    nn.Sequential(*self._conv_block(cin, cout, norm_layer, num_block=depth)))
            setattr(self, 'side_conv%d' % stage_idx,
                    nn.Conv2d(cout, num_classes, kernel_size=1, stride=1, bias=False))

        # fuse the five side predictions into a single map
        self.fuse_conv = nn.Conv2d(num_classes * 5, num_classes, kernel_size=1, stride=1, bias=False)
        self.maxpool = nn.MaxPool2d(2, stride=2)

    def _conv_block(self, in_nc, out_nc, norm_layer, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=False):
        """Return a list of ``num_block`` (Conv2d -> norm -> ReLU) layers."""
        layers = []
        for block_idx in range(num_block):
            src_nc = in_nc if block_idx == 0 else out_nc
            layers.extend([
                nn.Conv2d(src_nc, out_nc, kernel_size=kernel_size, stride=stride,
                          padding=padding, bias=bias),
                norm_layer(out_nc),
                nn.ReLU(True),
            ])
        return layers

    def forward(self, x):
        """Run the encoder; return (side1, ..., side5, fused), all at input resolution."""
        h, w = x.size()[2:]
        side_outputs = []
        feat = x
        for stage_idx in range(1, 6):
            if stage_idx > 1:
                feat = self.maxpool(feat)  # halve resolution between stages
            feat = getattr(self, 'conv%d' % stage_idx)(feat)
            side = getattr(self, 'side_conv%d' % stage_idx)(feat)
            if stage_idx > 1:
                # restore input resolution before fusing
                side = F.interpolate(side, size=(h, w), mode='bilinear', align_corners=True)
            side_outputs.append(side)

        fused = self.fuse_conv(torch.cat(side_outputs, dim=1))
        return tuple(side_outputs) + (fused,)
81
+
82
def define_deepcrack(in_nc,
                     num_classes,
                     ngf,
                     norm='batch',
                     init_type='xavier',
                     init_gain=0.02,
                     gpu_ids=[]):
    """Build a DeepCrackNet and hand it to init_net for weight init / GPU placement.

    Parameters:
        in_nc (int)       -- number of input channels
        num_classes (int) -- number of output prediction channels
        ngf (int)         -- base number of encoder filters
        norm (str)        -- normalization type passed to DeepCrackNet: batch | instance | none
        init_type (str)   -- weight initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
        gpu_ids (int list)-- which GPUs the network runs on, e.g. [0, 1]

    Returns the initialized network.
    """
    return init_net(DeepCrackNet(in_nc, num_classes, ngf, norm),
                    init_type, init_gain, gpu_ids)
91
+
92
+
93
class BinaryFocalLoss(nn.Module):
    """Binary focal loss: alpha * (1 - pt)**gamma * BCE, averaged or summed.

    Down-weights well-classified examples so training focuses on hard ones.

    NOTE(review): the ``logits`` flag is stored but never consulted —
    BCEWithLogitsLoss is always used, so ``inputs`` are always treated as raw
    logits regardless of the flag; confirm against callers.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, size_average=True):
        """alpha: class weight; gamma: focusing exponent; size_average: mean vs sum."""
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.size_average = size_average
        # per-element BCE so the focal modulation can be applied element-wise
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        """Compute the focal loss between logit ``inputs`` and 0/1 ``targets``."""
        bce_per_elem = self.criterion(inputs, targets)
        pt = torch.exp(-bce_per_elem)  # model's probability of the true class
        focal = self.alpha * (1 - pt) ** self.gamma * bce_per_elem
        return focal.mean() if self.size_average else focal.sum()
models/networks.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+
7
+
8
+ ###############################################################################
9
+ # Helper Functions
10
+ ###############################################################################
11
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics
    (mean/stddev).  For InstanceNorm, we do not use learnable affine parameters and
    do not track running statistics.  For 'none', ``None`` is returned.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
29
+
30
+
31
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of
                              BaseOptions.  opt.lr_policy is the name of the learning rate
                              policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch
    schedulers.  See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy names an unknown policy.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # constant for the first `niter` epochs, then linear decay to zero
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # Bug fix: the original *returned* a NotImplementedError instance (and passed
        # opt.lr_policy as a second argument instead of %-formatting it), so an
        # unknown policy silently yielded no scheduler.  Raise instead.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
58
+
59
+
60
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    'normal' is used in the original pix2pix and CycleGAN paper, but xavier and
    kaiming might work better for some applications.  Feel free to try yourself.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }

    def init_func(m):  # applied to every submodule by net.apply below
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type not in weight_initializers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_initializers[init_type](m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
92
+
93
+
94
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support);
    2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return the initialized network (wrapped in DataParallel when GPUs are given).
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
110
+
111
+
112
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator.

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator.

    Two generator families are provided:
        U-Net: [unet_128] (for 128x128 inputs) and [unet_256] (for 256x256 inputs).
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based: [resnet_6blocks] / [resnet_9blocks] — several Resnet blocks
        between a few downsampling/upsampling operations, adapted from Justin
        Johnson's neural style transfer project
        (https://github.com/jcjohnson/fast-neural-style).

    The generator is initialized by <init_net> and uses ReLU for non-linearity.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    builders = {
        'resnet_9blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9),
        'resnet_6blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6),
        'unet_128': lambda: UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
        'unet_256': lambda: UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
    }
    if netG not in builders:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(builders[netG](), init_type, init_gain, gpu_ids)
153
+
154
+
155
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator.

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator.

    Three discriminator types are provided:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It classifies whether 70x70 overlapping patches are real or fake.  Such a
        patch-level discriminator has fewer parameters than a full-image
        discriminator and works on arbitrarily-sized images in a fully
        convolutional fashion.

        [n_layers]: like [basic], but you can specify the number of conv layers
        with <n_layers_D> (default=3 as used in [basic]).

        [pixel]: 1x1 PixelGAN discriminator classifying whether each pixel is
        real or not.  It encourages greater color diversity but has no effect on
        spatial statistics.

    The discriminator is initialized by <init_net> and uses LeakyReLU for non-linearity.

    Raises:
        NotImplementedError -- if netD names an unknown architecture.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        # Bug fix: the original formatted `net` (always None on this branch)
        # into the message instead of the requested architecture name `netD`.
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
197
+
198
+
199
+ ##############################################################################
200
+ # Classes
201
+ ##############################################################################
202
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create a target label tensor
    that has the same size as the prediction.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective: vanilla | lsgan | wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # buffers so the labels follow the module between devices
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        criteria = {'lsgan': nn.MSELoss, 'vanilla': nn.BCEWithLogitsLoss, 'wgangp': None}
        if gan_mode not in criteria:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
        factory = criteria[gan_mode]
        self.loss = factory() if factory is not None else None  # wgangp needs no criterion

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            A label tensor filled with the ground-truth label, sized like the input.
        """
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate the loss given the Discriminator's output and ground-truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode == 'wgangp':
            mean = prediction.mean()
            return -mean if target_is_real else mean
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)
269
+
270
+
271
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss used in WGAN-GP (https://arxiv.org/abs/1704.00028).

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- whether to mix real and fake data or not [real | fake | mixed]
        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns (gradient_penalty, gradients) — or (0.0, None) when lambda_gp is not positive.
    """
    if lambda_gp <= 0.0:
        return 0.0, None

    # choose the points at which the gradient norm is penalized
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        # per-sample random convex combination of real and fake
        alpha = torch.rand(real_data.shape[0], 1)
        alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
        alpha = alpha.to(device)
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    interpolatesv.requires_grad_(True)
    disc_interpolates = netD(interpolatesv)
    grads = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolatesv,
        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    flat_grads = grads[0].view(real_data.size(0), -1)  # flatten per sample
    # small eps keeps norm() differentiable at zero
    gradient_penalty = (((flat_grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, flat_grads
307
+
308
+
309
class ResnetGenerator(nn.Module):
    """Resnet-based generator: Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer
    project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator.

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            ngf (int)          -- the number of filters in the last conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- if use dropout layers
            n_blocks (int)     -- the number of ResNet blocks
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        # InstanceNorm has no affine bias, so convs need their own bias then
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # downsampling: double channels, halve resolution
            in_ch = ngf * (2 ** i)
            layers += [nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(in_ch * 2),
                       nn.ReLU(True)]

        bottleneck_ch = ngf * (2 ** n_downsampling)
        layers += [ResnetBlock(bottleneck_ch, padding_type=padding_type, norm_layer=norm_layer,
                               use_dropout=use_dropout, use_bias=use_bias)
                   for _ in range(n_blocks)]  # ResNet blocks at the bottleneck resolution

        for i in range(n_downsampling):  # upsampling: halve channels, double resolution
            in_ch = ngf * (2 ** (n_downsampling - i))
            out_ch = int(in_ch / 2)
            layers += [nn.ConvTranspose2d(in_ch, out_ch,
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1,
                                          bias=use_bias),
                       norm_layer(out_ch),
                       nn.ReLU(True)]

        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
368
+
369
+
370
class ResnetBlock(nn.Module):
    """A single residual block: out = x + conv_block(x).

    Original ResNet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block.

        The conv block is built by <build_conv_block>; the skip connection is
        applied in <forward>.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the two-conv block.

        Parameters:
            dim (int)          -- the number of channels in the conv layer
            padding_type (str) -- the name of padding layer: reflect | replicate | zero
            norm_layer         -- normalization layer
            use_dropout (bool) -- if use dropout layers
            use_bias (bool)    -- if the conv layer uses bias or not

        Returns an nn.Sequential of (pad) + conv + norm + ReLU (+ dropout) +
        (pad) + conv + norm.
        """
        def padding():
            # Returns (explicit padding layers, conv `padding` argument).
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        pad_layers, p = padding()
        layers = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                               norm_layer(dim),
                               nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, p = padding()
        layers += pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                                norm_layer(dim)]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function (with skip connections)"""
        return x + self.conv_block(x)
428
+
429
+
430
class UnetGenerator(nn.Module):
    """U-Net generator, built recursively from the innermost block outward."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example,
                               if |num_downs| == 7, an image of size 128x128
                               becomes 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        The U-Net is assembled from the innermost layer to the outermost layer,
        one UnetSkipConnectionBlock wrapping the previous.
        """
        super(UnetGenerator, self).__init__()
        # innermost (bottleneck) layer
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        # intermediate layers keep ngf * 8 filters (optionally with dropout)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # taper filters back down: ngf*8 -> ngf*4 -> ngf*2 -> ngf
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * mult, ngf * mult * 2, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)
        # outermost layer maps to the requested number of output channels
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
460
+
461
+
462
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net stage: downsample, run the submodule, upsample, and (except
    at the outermost stage) concatenate the input as a skip connection.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
                              (defaults to outer_nc)
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool) -- if this module is the outermost module
            innermost (bool) -- if this module is the innermost module
            norm_layer -- normalization layer
            use_dropout (bool) -- if use dropout layers
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm has no affine parameters by default, so convs keep a bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)
        down_act = nn.LeakyReLU(0.2, True)
        up_act = nn.ReLU(True)

        if outermost:
            # no norm around the boundary convs; Tanh produces the final image
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1)
            model = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            # bottleneck: no submodule, so the upconv sees only inner_nc channels
            up_conv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            model = [down_act, down_conv, up_act, up_conv, norm_layer(outer_nc)]
        else:
            # intermediate stage: submodule doubles the channels via its skip
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            model = [down_act, down_conv, norm_layer(inner_nc),
                     submodule,
                     up_act, up_conv, norm_layer(outer_nc)]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Pass through; non-outermost stages concatenate the skip connection."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
530
+
531
+
532
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a conv stack producing a 1-channel map of
    real/fake scores for overlapping image patches."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so a conv bias would be redundant.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw, padw = 4, 1
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]

        # (input multiplier, output multiplier, stride) for each norm'd block;
        # filters double up to 8x ndf, the final block keeps the resolution.
        specs = [(min(2 ** (n - 1), 8), min(2 ** n, 8), 2) for n in range(1, n_layers)]
        specs.append((min(2 ** max(n_layers - 1, 0), 8), min(2 ** n_layers, 8), 1))

        for in_mult, out_mult, stride in specs:
            layers += [nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw,
                                 stride=stride, padding=padw, bias=use_bias),
                       norm_layer(ndf * out_mult),
                       nn.LeakyReLU(0.2, True)]

        # output 1 channel prediction map
        layers += [nn.Conv2d(ndf * min(2 ** n_layers, 8), 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
578
+
579
+
580
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN): every conv is 1x1, so
    each pixel is classified independently."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so a conv bias would be redundant.
        # Fix: compare against BatchNorm2d (as NLayerDiscriminator does) —
        # the old `!= nn.InstanceNorm2d` check inverted the intent, giving
        # BatchNorm convs a useless bias and InstanceNorm convs none.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
models/roadnet_model.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Yahui Liu <yahui.liu@unitn.it>
2
+
3
+ import torch
4
+ import numpy as np
5
+ import itertools
6
+ from .base_model import BaseModel
7
+ import torch.nn.functional as F
8
+ from .roadnet_networks import define_roadnet
9
+
10
+ class RoadNetModel(BaseModel):
11
+ """
12
+ This class implements the RoadNet model.
13
+ RoadNet paper: https://ieeexplore.ieee.org/document/8506600
14
+ """
15
+ @staticmethod
16
+ def modify_commandline_options(parser, is_train=True):
17
+ """Add new dataset-specific options, and rewrite default values for existing options."""
18
+ return parser
19
+
20
+ def __init__(self, opt):
21
+ """Initialize the RoadNet class.
22
+
23
+ Parameters:
24
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
25
+ """
26
+ BaseModel.__init__(self, opt)
27
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
28
+ self.loss_names = ['segment', 'edge', 'centerline']
29
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
30
+ self.visual_names = ['image', 'label_gt', 'label_pred']
31
+ # specify the models you want to save to the disk.
32
+ self.model_names = ['G']
33
+
34
+ # define networks
35
+ self.netG = define_roadnet(opt.input_nc,
36
+ opt.output_nc,
37
+ opt.ngf,
38
+ opt.norm,
39
+ opt.use_selu,
40
+ opt.init_type,
41
+ opt.init_gain,
42
+ self.gpu_ids)
43
+
44
+ if self.isTrain:
45
+ # define loss functions
46
+ self.weight_segment_side = [0.5, 0.75, 1.0, 0.75, 0.5, 1.0]
47
+ self.weight_others_side = [0.5, 0.75, 1.0, 0.75, 1.0]
48
+
49
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
50
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, eps=1e-3, weight_decay=2e-4)
51
+ #self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
52
+ self.optimizers.append(self.optimizer)
53
+
54
+ def _get_balanced_sigmoid_cross_entropy(self,x):
55
+ count_neg = torch.sum(1. - x)
56
+ count_pos = torch.sum(x)
57
+ beta = count_neg / (count_neg + count_pos)
58
+ pos_weight = beta / (1 - beta)
59
+ cost = torch.nn.BCEWithLogitsLoss(size_average=True, reduce=True, pos_weight=pos_weight)
60
+ return cost, 1-beta
61
+
62
+ def set_input(self, input):
63
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
64
+
65
+ Parameters:
66
+ input (dict): include the data itself and its metadata information.
67
+ """
68
+ self.image = input['image'].to(self.device)
69
+ self.segment_gt = input['segment'].to(self.device)
70
+ self.edge_gt = input['edge'].to(self.device)
71
+ self.centerline_gt = input['centerline'].to(self.device)
72
+ self.image_paths = input['A_paths']
73
+ if self.isTrain:
74
+ self.criterion_seg, self.beta_seg = self._get_balanced_sigmoid_cross_entropy(self.segment_gt)
75
+ self.criterion_edg, self.beta_edg = self._get_balanced_sigmoid_cross_entropy(self.edge_gt)
76
+ self.criterion_cnt, self.beta_cnt = self._get_balanced_sigmoid_cross_entropy(self.centerline_gt)
77
+
78
+ def forward(self):
79
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
80
+ self.segments, self.edges, self.centerlines = self.netG(self.image)
81
+
82
+ # for visualization
83
+ segment_gt_viz = (self.segment_gt-0.5)/0.5
84
+ edge_gt_viz = (self.edge_gt-0.5)/0.5
85
+ centerline_gt_viz = (self.centerline_gt-0.5)/0.5
86
+ self.label_gt = torch.cat([centerline_gt_viz, edge_gt_viz, segment_gt_viz], dim=1)
87
+
88
+ segment_fused = (torch.sigmoid(self.segments[-1])-0.5)/0.5
89
+ edge_fused = (torch.sigmoid(self.edges[-1])-0.5)/0.5
90
+ centerlines_fused = (torch.sigmoid(self.centerlines[-1])-0.5)/0.5
91
+ self.label_pred = torch.cat([centerlines_fused, edge_fused, segment_fused], dim=1)
92
+
93
+ def backward(self):
94
+ """Calculate the loss"""
95
+ self.loss_segment = torch.mean((torch.sigmoid(self.segments[-1])-self.segment_gt)**2) * 0.5
96
+ if self.segment_gt.sum() > 0.0: # ignore blank ones
97
+ for out, w in zip(self.segments, self.weight_segment_side):
98
+ self.loss_segment += self.criterion_seg(out, self.segment_gt) * self.beta_seg * w
99
+
100
+ self.loss_edge = torch.mean((torch.sigmoid(self.edges[-1])-self.edge_gt)**2) * 0.5
101
+ if self.edge_gt.sum() > 0.0:
102
+ for out, w in zip(self.edges, self.weight_others_side):
103
+ self.loss_edge += self.criterion_edg(out, self.edge_gt) * self.beta_edg * w
104
+
105
+ self.loss_centerline = torch.mean((torch.sigmoid(self.centerlines[-1])-self.centerline_gt)**2) * 0.5
106
+ if self.centerline_gt.sum() > 0.0:
107
+ for out, w in zip(self.centerlines, self.weight_others_side):
108
+ self.loss_centerline += self.criterion_cnt(out, self.centerline_gt) * self.beta_cnt * w
109
+
110
+ self.loss_total = self.loss_segment + self.loss_edge + self.loss_centerline
111
+ self.loss_total.backward()
112
+
113
+ def optimize_parameters(self, epoch=None):
114
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
115
+
116
+ # forward
117
+ self.forward() # compute predictions.
118
+ self.optimizer.zero_grad() # set G's gradients to zero
119
+ self.backward() # calculate gradients for G
120
+ self.optimizer.step() # update G's weights
models/roadnet_networks.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! -*- coding: utf-8 -*-
2
+ # Author: Yahui Liu <yahui.liu@unitn.it>
3
+
4
+ """
5
+ Reference:
6
+
7
+ RoadNet: Learning to Comprehensively Analyze Road Networks in Complex Urban Scenes
8
+ From High-Resolution Remotely Sensed Images.
9
+ https://ieeexplore.ieee.org/document/8506600
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from .networks import get_norm_layer, init_net
16
+
17
class RoadNet(nn.Module):
    """Multi-task network from the RoadNet paper: predicts road-surface
    segmentation, road edges and road centerlines from a single image.

    The segmentation branch consumes the raw image; the edge and centerline
    branches consume the image concatenated with the fused segmentation map.
    Each branch returns its side-output logits plus a fused prediction.
    """
    def __init__(self, in_nc, out_nc, ngf, norm='batch', use_selu=1):
        super(RoadNet, self).__init__()
        norm_layer = get_norm_layer(norm_type=norm)

        #------------road surface segmentation------------#
        # VGG-style trunk; a 1x1 side-output conv taps each stage
        self.segment_conv1 = nn.Sequential(*self._conv_block(in_nc, ngf, norm_layer, use_selu, num_block=2))
        self.side_segment_conv1 = nn.Conv2d(ngf, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv2 = nn.Sequential(*self._conv_block(ngf, ngf*2, norm_layer, use_selu, num_block=2))
        self.side_segment_conv2 = nn.Conv2d(ngf*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv3 = nn.Sequential(*self._conv_block(ngf*2, ngf*4, norm_layer, use_selu, num_block=3))
        self.side_segment_conv3 = nn.Conv2d(ngf*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv4 = nn.Sequential(*self._conv_block(ngf*4, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv4 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv5 = nn.Sequential(*self._conv_block(ngf*8, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv5 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_segment_conv = nn.Conv2d(out_nc*5, out_nc, kernel_size=1, stride=1, bias=False)

        ngf2 = ngf//2
        #------------road edge detection------------#
        self.edge_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_edge_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_edge_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_edge_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        #------------road centerline extraction------------#
        self.centerline_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_centerline_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.maxpool = nn.MaxPool2d(2, stride=2)

    def _conv_block(self, in_nc, out_nc, norm_layer, use_selu, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=True):
        """Return a list of `num_block` conv layers, each followed by SELU
        (when `use_selu` is truthy) or by norm + ReLU."""
        conv = []
        for i in range(num_block):
            cur_in_nc = in_nc if i == 0 else out_nc
            conv += [nn.Conv2d(cur_in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=bias)]
            if use_selu:
                # fix: the activation class is nn.SELU; `nn.SeLU` does not
                # exist and raised AttributeError for the default use_selu=1
                conv += [nn.SELU(True)]
            else:
                conv += [norm_layer(out_nc), nn.ReLU(True)]
        return conv

    def _segment_forward(self, x):
        """
        predict road surface segmentation
        :param: x, image tensor, [N, C, H, W]
        Returns [side1..side5, fused] logits, all upsampled to the input size.
        """
        h,w = x.size()[2:]
        # main stream features
        conv1 = self.segment_conv1(x)
        conv2 = self.segment_conv2(self.maxpool(conv1))
        conv3 = self.segment_conv3(self.maxpool(conv2))
        conv4 = self.segment_conv4(self.maxpool(conv3))
        conv5 = self.segment_conv5(self.maxpool(conv4))
        # side output features
        side_output1 = self.side_segment_conv1(conv1)
        side_output2 = self.side_segment_conv2(conv2)
        side_output3 = self.side_segment_conv3(conv3)
        side_output4 = self.side_segment_conv4(conv4)
        side_output5 = self.side_segment_conv5(conv5)
        # upsampling side output features
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        side_output5 = F.interpolate(side_output5, size=(h, w), mode='bilinear', align_corners=True)

        fused = self.fuse_segment_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4,
            side_output5], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, side_output5, fused]

    def _edge_forward(self, x):
        """
        predict road edge
        :param: x, [image tensor, predicted segmentation tensor], [N, C+1, H, W]
        Returns [side1..side4, fused] logits, all upsampled to the input size.
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.edge_conv1(x)
        conv2 = self.edge_conv2(self.maxpool(conv1))
        conv3 = self.edge_conv3(self.maxpool(conv2))
        conv4 = self.edge_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_edge_conv1(conv1)
        side_output2 = self.side_edge_conv2(conv2)
        side_output3 = self.side_edge_conv3(conv3)
        side_output4 = self.side_edge_conv4(conv4)
        # upsampling side output features
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        fused = self.fuse_edge_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def _centerline_forward(self, x):
        """
        predict road centerline
        :param: x, [image tensor, predicted segmentation tensor], [N, C+1, H, W]
        Returns [side1..side4, fused] logits, all upsampled to the input size.
        """
        h,w = x.size()[2:]
        # main stream features
        conv1 = self.centerline_conv1(x)
        conv2 = self.centerline_conv2(self.maxpool(conv1))
        conv3 = self.centerline_conv3(self.maxpool(conv2))
        conv4 = self.centerline_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_centerline_conv1(conv1)
        side_output2 = self.side_centerline_conv2(conv2)
        side_output3 = self.side_centerline_conv3(conv3)
        side_output4 = self.side_centerline_conv4(conv4)
        # upsampling side output features
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        fused = self.fuse_centerline_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def forward(self, x):
        """Return (segments, edges, centerlines); each is a list of side
        outputs with the fused prediction last (all raw logits)."""
        segments = self._segment_forward(x)
        # condition the edge/centerline branches on the fused segmentation
        x_ = torch.cat([x, segments[-1]], dim=1)
        edges = self._edge_forward(x_)
        centerlines = self._centerline_forward(x_)
        return segments, edges, centerlines
184
+
185
def define_roadnet(in_nc,
                   out_nc,
                   ngf,
                   norm='batch',
                   use_selu=1,
                   init_type='xavier',
                   init_gain=0.02,
                   gpu_ids=[]):
    """Create a RoadNet and initialize its weights.

    Parameters:
        in_nc (int)       -- number of input image channels
        out_nc (int)      -- number of output channels per prediction map
        ngf (int)         -- base number of filters
        norm (str)        -- normalization layer type passed to RoadNet
        use_selu (int)    -- use SELU activations instead of norm + ReLU
        init_type (str)   -- weight initialization method
        init_gain (float) -- scaling factor for the initialization
        gpu_ids (list)    -- GPUs to place the network on

    Returns the network after weight init / device placement by init_net.
    """
    network = RoadNet(in_nc, out_nc, ngf, norm, use_selu)
    return init_net(network, init_type, init_gain, gpu_ids)