Thompson001 commited on
Commit
c3d31f9
·
verified ·
1 Parent(s): 873a70d

Delete models

Browse files
models/__init__.py DELETED
@@ -1,68 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
-
21
- import importlib
22
- from .base_model import BaseModel
23
-
24
-
25
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return the model class.

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.

    Parameters:
        model_name (str) -- the model name used to build the module/class names,
                            e.g. 'deepcrack' -> models.deepcrack_model.DeepCrackModel

    Exits the process with a non-zero status if no matching class is found.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    # class-name match is case-insensitive and ignores underscores in model_name
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # BUG FIX: was exit(0), which reports success to the shell on a
        # lookup failure; exit with a non-zero status instead.
        exit(1)

    return model
46
-
47
-
48
def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    return find_model_using_name(model_name).modify_commandline_options
52
-
53
-
54
def create_model(opt):
    """Create a model given the option.

    This function instantiates the model class found by <find_model_using_name>.
    This is the main interface between this package and 'train.py'/'test.py'

    Parameters:
        opt (Option class) -- stores all the experiment flags; needs to be a
                              subclass of BaseOptions. opt.model selects the class.

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    print(model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/base_model.py DELETED
@@ -1,229 +0,0 @@
1
- import os
2
- import torch
3
- from collections import OrderedDict
4
- from abc import ABC, abstractmethod
5
- from . import networks
6
-
7
-
8
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        if hasattr(opt, 'checkpoints_dir'):
            self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if not hasattr(opt, 'preprocess') or opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            # resume from a specific iteration if requested, else from a named epoch
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                # plateau scheduling steps on the tracked metric, not the epoch
                scheduler.step(self.metric)
            else:
                scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # save a device-agnostic (CPU) state dict, then move the net back to GPU
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            # drop InstanceNorm running stats / batch counters that newer
            # PyTorch versions no longer store when track_running_stats=False
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict, strict=False)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations

        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/deepcrack_model.py DELETED
@@ -1,116 +0,0 @@
1
- # Author: Yahui Liu <yahui.liu@uintn.it>
2
-
3
- import torch
4
- import numpy as np
5
- import itertools
6
- from .base_model import BaseModel
7
- from .deepcrack_networks import define_deepcrack, BinaryFocalLoss
8
-
9
- class DeepCrackModel(BaseModel):
10
- """
11
- This class implements the DeepCrack model.
12
- DeepCrack paper: https://www.sciencedirect.com/science/article/pii/S0925231219300566
13
- """
14
- @staticmethod
15
- def modify_commandline_options(parser, is_train=True):
16
- """Add new dataset-specific options, and rewrite default values for existing options."""
17
- parser.add_argument('--lambda_side', type=float, default=1.0, help='weight for side output loss')
18
- parser.add_argument('--lambda_fused', type=float, default=1.0, help='weight for fused loss')
19
- return parser
20
-
21
- def __init__(self, opt):
22
- """Initialize the DeepCrack class.
23
- Parameters:
24
- opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
25
- """
26
- BaseModel.__init__(self, opt)
27
- # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
28
- self.loss_names = ['side', 'fused', 'total']
29
- # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
30
- self.display_sides = opt.display_sides
31
- self.visual_names = ['image', 'label_viz', 'fused']
32
- if self.display_sides:
33
- self.visual_names += ['side1', 'side2', 'side3', 'side4', 'side5']
34
- # specify the models you want to save to the disk.
35
- self.model_names = ['G']
36
-
37
- # define networks
38
- self.netG = define_deepcrack(opt.input_nc,
39
- opt.num_classes,
40
- opt.ngf,
41
- opt.norm,
42
- opt.init_type,
43
- opt.init_gain,
44
- self.gpu_ids)
45
-
46
- self.softmax = torch.nn.Softmax(dim=1)
47
-
48
- if self.isTrain:
49
- # define loss functions
50
- #self.weight = torch.from_numpy(np.array([0.0300, 1.0000], dtype='float32')).float().to(self.device)
51
- #self.criterionSeg = torch.nn.CrossEntropyLoss(weight=self.weight)
52
- if self.opt.loss_mode == 'focal':
53
- self.criterionSeg = BinaryFocalLoss()
54
- else:
55
- self.criterionSeg = nn.BCEWithLogitsLoss(size_average=True, reduce=True,
56
- pos_weight=torch.tensor(1.0/3e-2).to(self.device))
57
- self.weight_side = [0.5, 0.75, 1.0, 0.75, 0.5]
58
-
59
- # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
60
- self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
61
- self.optimizers.append(self.optimizer)
62
-
63
- def set_input(self, input):
64
- """Unpack input data from the dataloader and perform necessary pre-processing steps.
65
- Parameters:
66
- input (dict): include the data itself and its metadata information.
67
- """
68
- self.image = input['image'].to(self.device)
69
- self.label = input['label'].to(self.device)
70
- #self.label3d = self.label.squeeze(1)
71
- self.image_paths = input['A_paths']
72
-
73
- def forward(self):
74
- """Run forward pass; called by both functions <optimize_parameters> and <test>."""
75
- self.outputs = self.netG(self.image)
76
-
77
- # for visualization
78
- self.label_viz = (self.label.float()-0.5)/0.5
79
- #self.fused = (self.softmax(self.outputs[-1])[:,1].detach().unsqueeze(1)-0.5)/0.5
80
- #if self.display_sides:
81
- # self.side1 = (self.softmax(self.outputs[0])[:,1].detach().unsqueeze(1)-0.5)/0.5
82
- # self.side2 = (self.softmax(self.outputs[1])[:,1].detach().unsqueeze(1)-0.5)/0.5
83
- # self.side3 = (self.softmax(self.outputs[2])[:,1].detach().unsqueeze(1)-0.5)/0.5
84
- # self.side4 = (self.softmax(self.outputs[3])[:,1].detach().unsqueeze(1)-0.5)/0.5
85
- # self.side5 = (self.softmax(self.outputs[4])[:,1].detach().unsqueeze(1)-0.5)/0.5
86
- self.fused = (torch.sigmoid(self.outputs[-1])-0.5)/0.5
87
-
88
- if self.display_sides:
89
- self.side1 = (torch.sigmoid(self.outputs[0])-0.5)/0.5
90
- self.side2 = (torch.sigmoid(self.outputs[1])-0.5)/0.5
91
- self.side3 = (torch.sigmoid(self.outputs[2])-0.5)/0.5
92
- self.side4 = (torch.sigmoid(self.outputs[3])-0.5)/0.5
93
- self.side5 = (torch.sigmoid(self.outputs[4])-0.5)/0.5
94
-
95
- def backward(self):
96
- """Calculate the loss"""
97
- lambda_side = self.opt.lambda_side
98
- lambda_fused = self.opt.lambda_fused
99
-
100
- self.loss_side = 0.0
101
- for out, w in zip(self.outputs[:-1], self.weight_side):
102
- #self.loss_side += self.criterionSeg(out, self.label3d) * w
103
- self.loss_side += self.criterionSeg(out, self.label) * w
104
-
105
- #self.loss_fused = self.criterionSeg(self.outputs[-1], self.label3d)
106
- self.loss_fused = self.criterionSeg(self.outputs[-1], self.label)
107
- self.loss_total = self.loss_side * lambda_side + self.loss_fused * lambda_fused
108
- self.loss_total.backward()
109
-
110
- def optimize_parameters(self, epoch=None):
111
- """Calculate losses, gradients, and update network weights; called in every training iteration"""
112
- # forward
113
- self.forward() # compute predictions.
114
- self.optimizer.zero_grad() # set G's gradients to zero
115
- self.backward() # calculate gradients for G
116
- self.optimizer.step() # update G's weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/deepcrack_networks.py DELETED
@@ -1,110 +0,0 @@
1
- #! -*- coding: utf-8 -*-
2
- # Author: Yahui Liu <yahui.liu@unitn.it>
3
-
4
- """
5
- Reference:
6
-
7
- DeepCrack: A deep hierarchical feature learning architecture for crack segmentation.
8
- https://www.sciencedirect.com/science/article/pii/S0925231219300566
9
- """
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from .networks import get_norm_layer, init_net
15
-
16
class DeepCrackNet(nn.Module):
    """VGG-style encoder with per-scale side outputs and a 1x1 fusion head.

    Five encoder stages each feed a 1x1 "side" convolution; the coarser side
    maps are upsampled back to the input resolution and fused by a final 1x1
    convolution over their channel-wise concatenation.
    """

    def __init__(self, in_nc, num_classes, ngf, norm='batch'):
        super(DeepCrackNet, self).__init__()

        norm_layer = get_norm_layer(norm_type=norm)
        # NOTE: attribute names and creation order are kept exactly as in the
        # original so state-dict keys and weight-init RNG order are unchanged.
        self.conv1 = nn.Sequential(*self._conv_block(in_nc, ngf, norm_layer, num_block=2))
        self.side_conv1 = nn.Conv2d(ngf, num_classes, kernel_size=1, stride=1, bias=False)

        self.conv2 = nn.Sequential(*self._conv_block(ngf, ngf*2, norm_layer, num_block=2))
        self.side_conv2 = nn.Conv2d(ngf*2, num_classes, kernel_size=1, stride=1, bias=False)

        self.conv3 = nn.Sequential(*self._conv_block(ngf*2, ngf*4, norm_layer, num_block=3))
        self.side_conv3 = nn.Conv2d(ngf*4, num_classes, kernel_size=1, stride=1, bias=False)

        self.conv4 = nn.Sequential(*self._conv_block(ngf*4, ngf*8, norm_layer, num_block=3))
        self.side_conv4 = nn.Conv2d(ngf*8, num_classes, kernel_size=1, stride=1, bias=False)

        self.conv5 = nn.Sequential(*self._conv_block(ngf*8, ngf*8, norm_layer, num_block=3))
        self.side_conv5 = nn.Conv2d(ngf*8, num_classes, kernel_size=1, stride=1, bias=False)

        self.fuse_conv = nn.Conv2d(num_classes*5, num_classes, kernel_size=1, stride=1, bias=False)
        self.maxpool = nn.MaxPool2d(2, stride=2)

    def _conv_block(self, in_nc, out_nc, norm_layer, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=False):
        """Return a flat list of num_block (Conv2d -> norm -> ReLU) layer triples."""
        layers = []
        for idx in range(num_block):
            src_nc = out_nc if idx else in_nc  # only the first conv changes the channel count
            layers.extend([
                nn.Conv2d(src_nc, out_nc, kernel_size=kernel_size,
                          stride=stride, padding=padding, bias=bias),
                norm_layer(out_nc),
                nn.ReLU(True),
            ])
        return layers

    def forward(self, x):
        h, w = x.size()[2:]
        # Encoder: each subsequent stage runs on the max-pooled previous features.
        feat1 = self.conv1(x)
        feat2 = self.conv2(self.maxpool(feat1))
        feat3 = self.conv3(self.maxpool(feat2))
        feat4 = self.conv4(self.maxpool(feat3))
        feat5 = self.conv5(self.maxpool(feat4))
        # Per-scale side predictions.
        sides = [
            self.side_conv1(feat1),
            self.side_conv2(feat2),
            self.side_conv3(feat3),
            self.side_conv4(feat4),
            self.side_conv5(feat5),
        ]
        # Upsample the coarser side outputs back to the input resolution.
        sides[1:] = [F.interpolate(s, size=(h, w), mode='bilinear', align_corners=True)
                     for s in sides[1:]]
        # Fuse all scales with a 1x1 convolution over concatenated channels.
        fused = self.fuse_conv(torch.cat(sides, dim=1))
        return sides[0], sides[1], sides[2], sides[3], sides[4], fused
81
-
82
def define_deepcrack(in_nc,
                     num_classes,
                     ngf,
                     norm='batch',
                     init_type='xavier',
                     init_gain=0.02,
                     gpu_ids=None):
    """Create a DeepCrackNet and initialize its weights.

    Parameters:
        in_nc (int)       -- number of input image channels
        num_classes (int) -- number of output channels per side/fused head
        ngf (int)         -- base number of filters in the first encoder stage
        norm (str)        -- normalization layer name: batch | instance | none
        init_type (str)   -- weight initialization method name
        init_gain (float) -- scaling factor for the initialization
        gpu_ids (list)    -- GPUs to place the network on; defaults to [] (CPU).
                             BUG FIX: the original used a mutable default
                             argument (gpu_ids=[]), which is shared across
                             calls; a None sentinel is used instead.

    Returns the initialized network (see init_net).
    """
    if gpu_ids is None:
        gpu_ids = []
    net = DeepCrackNet(in_nc, num_classes, ngf, norm)
    return init_net(net, init_type, init_gain, gpu_ids)
91
-
92
-
93
class BinaryFocalLoss(nn.Module):
    """Binary focal loss on logits: alpha * (1 - p_t)^gamma * BCE.

    The per-element BCE-with-logits loss is modulated by (1 - p_t)^gamma so
    that well-classified elements (p_t near 1) contribute little, then scaled
    by alpha and reduced by mean (size_average=True) or sum.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, size_average=True):
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.size_average = size_average
        # per-element loss; reduction happens in forward()
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        bce = self.criterion(inputs, targets)
        # p_t = exp(-BCE) is the predicted probability of the true class
        modulator = (1 - torch.exp(-bce)) ** self.gamma
        focal = self.alpha * modulator * bce
        return focal.mean() if self.size_average else focal.sum()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/networks.py DELETED
@@ -1,609 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.nn import init
4
- import functools
5
- from torch.optim import lr_scheduler
6
-
7
-
8
- ###############################################################################
9
- # Helper Functions
10
- ###############################################################################
11
- def get_norm_layer(norm_type='instance'):
12
- """Return a normalization layer
13
-
14
- Parameters:
15
- norm_type (str) -- the name of the normalization layer: batch | instance | none
16
-
17
- For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
18
- For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
19
- """
20
- if norm_type == 'batch':
21
- norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
22
- elif norm_type == 'instance':
23
- norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
24
- elif norm_type == 'none':
25
- norm_layer = None
26
- else:
27
- raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
28
- return norm_layer
29
-
30
-
31
- def get_scheduler(optimizer, opt):
32
- """Return a learning rate scheduler
33
-
34
- Parameters:
35
- optimizer -- the optimizer of the network
36
- opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
37
- opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
38
-
39
- For 'linear', we keep the same learning rate for the first <opt.niter> epochs
40
- and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
41
- For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
42
- See https://pytorch.org/docs/stable/optim.html for more details.
43
- """
44
- if opt.lr_policy == 'linear':
45
- def lambda_rule(epoch):
46
- lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
47
- return lr_l
48
- scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
49
- elif opt.lr_policy == 'step':
50
- scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
51
- elif opt.lr_policy == 'plateau':
52
- scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
53
- elif opt.lr_policy == 'cosine':
54
- scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
55
- else:
56
- return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
57
- return scheduler
58
-
59
-
60
- def init_weights(net, init_type='normal', init_gain=0.02):
61
- """Initialize network weights.
62
-
63
- Parameters:
64
- net (network) -- network to be initialized
65
- init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
66
- init_gain (float) -- scaling factor for normal, xavier and orthogonal.
67
-
68
- We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
69
- work better for some applications. Feel free to try yourself.
70
- """
71
- def init_func(m): # define the initialization function
72
- classname = m.__class__.__name__
73
- if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
74
- if init_type == 'normal':
75
- init.normal_(m.weight.data, 0.0, init_gain)
76
- elif init_type == 'xavier':
77
- init.xavier_normal_(m.weight.data, gain=init_gain)
78
- elif init_type == 'kaiming':
79
- init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
80
- elif init_type == 'orthogonal':
81
- init.orthogonal_(m.weight.data, gain=init_gain)
82
- else:
83
- raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
84
- if hasattr(m, 'bias') and m.bias is not None:
85
- init.constant_(m.bias.data, 0.0)
86
- elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
87
- init.normal_(m.weight.data, 1.0, init_gain)
88
- init.constant_(m.bias.data, 0.0)
89
-
90
- print('initialize network with %s' % init_type)
91
- net.apply(init_func) # apply the initialization function <init_func>
92
-
93
-
94
- def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
95
- """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
96
- Parameters:
97
- net (network) -- the network to be initialized
98
- init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
99
- gain (float) -- scaling factor for normal, xavier and orthogonal.
100
- gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
101
-
102
- Return an initialized network.
103
- """
104
- if len(gpu_ids) > 0:
105
- assert(torch.cuda.is_available())
106
- net.to(gpu_ids[0])
107
- net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
108
- init_weights(net, init_type, init_gain=init_gain)
109
- return net
110
-
111
-
112
- def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
113
- """Create a generator
114
-
115
- Parameters:
116
- input_nc (int) -- the number of channels in input images
117
- output_nc (int) -- the number of channels in output images
118
- ngf (int) -- the number of filters in the last conv layer
119
- netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
120
- norm (str) -- the name of normalization layers used in the network: batch | instance | none
121
- use_dropout (bool) -- if use dropout layers.
122
- init_type (str) -- the name of our initialization method.
123
- init_gain (float) -- scaling factor for normal, xavier and orthogonal.
124
- gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
125
-
126
- Returns a generator
127
-
128
- Our current implementation provides two types of generators:
129
- U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
130
- The original U-Net paper: https://arxiv.org/abs/1505.04597
131
-
132
- Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
133
- Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
134
- We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
135
-
136
-
137
- The generator has been initialized by <init_net>. It uses RELU for non-linearity.
138
- """
139
- net = None
140
- norm_layer = get_norm_layer(norm_type=norm)
141
-
142
- if netG == 'resnet_9blocks':
143
- net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
144
- elif netG == 'resnet_6blocks':
145
- net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
146
- elif netG == 'unet_128':
147
- net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
148
- elif netG == 'unet_256':
149
- net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
150
- else:
151
- raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
152
- return init_net(net, init_type, init_gain, gpu_ids)
153
-
154
-
155
- def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
156
- """Create a discriminator
157
-
158
- Parameters:
159
- input_nc (int) -- the number of channels in input images
160
- ndf (int) -- the number of filters in the first conv layer
161
- netD (str) -- the architecture's name: basic | n_layers | pixel
162
- n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
163
- norm (str) -- the type of normalization layers used in the network.
164
- init_type (str) -- the name of the initialization method.
165
- init_gain (float) -- scaling factor for normal, xavier and orthogonal.
166
- gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
167
-
168
- Returns a discriminator
169
-
170
- Our current implementation provides three types of discriminators:
171
- [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
172
- It can classify whether 70×70 overlapping patches are real or fake.
173
- Such a patch-level discriminator architecture has fewer parameters
174
- than a full-image discriminator and can work on arbitrarily-sized images
175
- in a fully convolutional fashion.
176
-
177
- [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator
178
- with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
179
-
180
- [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
181
- It encourages greater color diversity but has no effect on spatial statistics.
182
-
183
- The discriminator has been initialized by <init_net>. It uses Leakly RELU for non-linearity.
184
- """
185
- net = None
186
- norm_layer = get_norm_layer(norm_type=norm)
187
-
188
- if netD == 'basic': # default PatchGAN classifier
189
- net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
190
- elif netD == 'n_layers': # more options
191
- net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
192
- elif netD == 'pixel': # classify if each pixel is real or fake
193
- net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
194
- else:
195
- raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
196
- return init_net(net, init_type, init_gain, gpu_ids)
197
-
198
-
199
- ##############################################################################
200
- # Classes
201
- ##############################################################################
202
- class GANLoss(nn.Module):
203
- """Define different GAN objectives.
204
-
205
- The GANLoss class abstracts away the need to create the target label tensor
206
- that has the same size as the input.
207
- """
208
-
209
- def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
210
- """ Initialize the GANLoss class.
211
-
212
- Parameters:
213
- gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
214
- target_real_label (bool) - - label for a real image
215
- target_fake_label (bool) - - label of a fake image
216
-
217
- Note: Do not use sigmoid as the last layer of Discriminator.
218
- LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
219
- """
220
- super(GANLoss, self).__init__()
221
- self.register_buffer('real_label', torch.tensor(target_real_label))
222
- self.register_buffer('fake_label', torch.tensor(target_fake_label))
223
- self.gan_mode = gan_mode
224
- if gan_mode == 'lsgan':
225
- self.loss = nn.MSELoss()
226
- elif gan_mode == 'vanilla':
227
- self.loss = nn.BCEWithLogitsLoss()
228
- elif gan_mode in ['wgangp']:
229
- self.loss = None
230
- else:
231
- raise NotImplementedError('gan mode %s not implemented' % gan_mode)
232
-
233
- def get_target_tensor(self, prediction, target_is_real):
234
- """Create label tensors with the same size as the input.
235
-
236
- Parameters:
237
- prediction (tensor) - - tpyically the prediction from a discriminator
238
- target_is_real (bool) - - if the ground truth label is for real images or fake images
239
-
240
- Returns:
241
- A label tensor filled with ground truth label, and with the size of the input
242
- """
243
-
244
- if target_is_real:
245
- target_tensor = self.real_label
246
- else:
247
- target_tensor = self.fake_label
248
- return target_tensor.expand_as(prediction)
249
-
250
- def __call__(self, prediction, target_is_real):
251
- """Calculate loss given Discriminator's output and grount truth labels.
252
-
253
- Parameters:
254
- prediction (tensor) - - tpyically the prediction output from a discriminator
255
- target_is_real (bool) - - if the ground truth label is for real images or fake images
256
-
257
- Returns:
258
- the calculated loss.
259
- """
260
- if self.gan_mode in ['lsgan', 'vanilla']:
261
- target_tensor = self.get_target_tensor(prediction, target_is_real)
262
- loss = self.loss(prediction, target_tensor)
263
- elif self.gan_mode == 'wgangp':
264
- if target_is_real:
265
- loss = -prediction.mean()
266
- else:
267
- loss = prediction.mean()
268
- return loss
269
-
270
-
271
- def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
272
- """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
273
-
274
- Arguments:
275
- netD (network) -- discriminator network
276
- real_data (tensor array) -- real images
277
- fake_data (tensor array) -- generated images from the generator
278
- device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
279
- type (str) -- if we mix real and fake data or not [real | fake | mixed].
280
- constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
281
- lambda_gp (float) -- weight for this loss
282
-
283
- Returns the gradient penalty loss
284
- """
285
- if lambda_gp > 0.0:
286
- if type == 'real': # either use real images, fake images, or a linear interpolation of two.
287
- interpolatesv = real_data
288
- elif type == 'fake':
289
- interpolatesv = fake_data
290
- elif type == 'mixed':
291
- alpha = torch.rand(real_data.shape[0], 1)
292
- alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
293
- alpha = alpha.to(device)
294
- interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
295
- else:
296
- raise NotImplementedError('{} not implemented'.format(type))
297
- interpolatesv.requires_grad_(True)
298
- disc_interpolates = netD(interpolatesv)
299
- gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
300
- grad_outputs=torch.ones(disc_interpolates.size()).to(device),
301
- create_graph=True, retain_graph=True, only_inputs=True)
302
- gradients = gradients[0].view(real_data.size(0), -1) # flat the data
303
- gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
304
- return gradient_penalty, gradients
305
- else:
306
- return 0.0, None
307
-
308
-
309
- class ResnetGenerator(nn.Module):
310
- """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
311
-
312
- We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
313
- """
314
-
315
- def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
316
- """Construct a Resnet-based generator
317
-
318
- Parameters:
319
- input_nc (int) -- the number of channels in input images
320
- output_nc (int) -- the number of channels in output images
321
- ngf (int) -- the number of filters in the last conv layer
322
- norm_layer -- normalization layer
323
- use_dropout (bool) -- if use dropout layers
324
- n_blocks (int) -- the number of ResNet blocks
325
- padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
326
- """
327
- assert(n_blocks >= 0)
328
- super(ResnetGenerator, self).__init__()
329
- if type(norm_layer) == functools.partial:
330
- use_bias = norm_layer.func == nn.InstanceNorm2d
331
- else:
332
- use_bias = norm_layer == nn.InstanceNorm2d
333
-
334
- model = [nn.ReflectionPad2d(3),
335
- nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
336
- norm_layer(ngf),
337
- nn.ReLU(True)]
338
-
339
- n_downsampling = 2
340
- for i in range(n_downsampling): # add downsampling layers
341
- mult = 2 ** i
342
- model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
343
- norm_layer(ngf * mult * 2),
344
- nn.ReLU(True)]
345
-
346
- mult = 2 ** n_downsampling
347
- for i in range(n_blocks): # add ResNet blocks
348
-
349
- model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
350
-
351
- for i in range(n_downsampling): # add upsampling layers
352
- mult = 2 ** (n_downsampling - i)
353
- model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
354
- kernel_size=3, stride=2,
355
- padding=1, output_padding=1,
356
- bias=use_bias),
357
- norm_layer(int(ngf * mult / 2)),
358
- nn.ReLU(True)]
359
- model += [nn.ReflectionPad2d(3)]
360
- model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
361
- model += [nn.Tanh()]
362
-
363
- self.model = nn.Sequential(*model)
364
-
365
- def forward(self, input):
366
- """Standard forward"""
367
- return self.model(input)
368
-
369
-
370
- class ResnetBlock(nn.Module):
371
- """Define a Resnet block"""
372
-
373
- def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
374
- """Initialize the Resnet block
375
-
376
- A resnet block is a conv block with skip connections
377
- We construct a conv block with build_conv_block function,
378
- and implement skip connections in <forward> function.
379
- Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
380
- """
381
- super(ResnetBlock, self).__init__()
382
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
383
-
384
- def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
385
- """Construct a convolutional block.
386
-
387
- Parameters:
388
- dim (int) -- the number of channels in the conv layer.
389
- padding_type (str) -- the name of padding layer: reflect | replicate | zero
390
- norm_layer -- normalization layer
391
- use_dropout (bool) -- if use dropout layers.
392
- use_bias (bool) -- if the conv layer uses bias or not
393
-
394
- Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
395
- """
396
- conv_block = []
397
- p = 0
398
- if padding_type == 'reflect':
399
- conv_block += [nn.ReflectionPad2d(1)]
400
- elif padding_type == 'replicate':
401
- conv_block += [nn.ReplicationPad2d(1)]
402
- elif padding_type == 'zero':
403
- p = 1
404
- else:
405
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
406
-
407
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
408
- if use_dropout:
409
- conv_block += [nn.Dropout(0.5)]
410
-
411
- p = 0
412
- if padding_type == 'reflect':
413
- conv_block += [nn.ReflectionPad2d(1)]
414
- elif padding_type == 'replicate':
415
- conv_block += [nn.ReplicationPad2d(1)]
416
- elif padding_type == 'zero':
417
- p = 1
418
- else:
419
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
420
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
421
-
422
- return nn.Sequential(*conv_block)
423
-
424
- def forward(self, x):
425
- """Forward function (with skip connections)"""
426
- out = x + self.conv_block(x) # add skip connections
427
- return out
428
-
429
-
430
- class UnetGenerator(nn.Module):
431
- """Create a Unet-based generator"""
432
-
433
- def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
434
- """Construct a Unet generator
435
- Parameters:
436
- input_nc (int) -- the number of channels in input images
437
- output_nc (int) -- the number of channels in output images
438
- num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
439
- image of size 128x128 will become of size 1x1 # at the bottleneck
440
- ngf (int) -- the number of filters in the last conv layer
441
- norm_layer -- normalization layer
442
-
443
- We construct the U-Net from the innermost layer to the outermost layer.
444
- It is a recursive process.
445
- """
446
- super(UnetGenerator, self).__init__()
447
- # construct unet structure
448
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
449
- for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
450
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
451
- # gradually reduce the number of filters from ngf * 8 to ngf
452
- unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
453
- unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
454
- unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
455
- self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
456
-
457
- def forward(self, input):
458
- """Standard forward"""
459
- return self.model(input)
460
-
461
-
462
- class UnetSkipConnectionBlock(nn.Module):
463
- """Defines the Unet submodule with skip connection.
464
- X -------------------identity----------------------
465
- |-- downsampling -- |submodule| -- upsampling --|
466
- """
467
-
468
- def __init__(self, outer_nc, inner_nc, input_nc=None,
469
- submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
470
- """Construct a Unet submodule with skip connections.
471
-
472
- Parameters:
473
- outer_nc (int) -- the number of filters in the outer conv layer
474
- inner_nc (int) -- the number of filters in the inner conv layer
475
- input_nc (int) -- the number of channels in input images/features
476
- submodule (UnetSkipConnectionBlock) -- previously defined submodules
477
- outermost (bool) -- if this module is the outermost module
478
- innermost (bool) -- if this module is the innermost module
479
- norm_layer -- normalization layer
480
- user_dropout (bool) -- if use dropout layers.
481
- """
482
- super(UnetSkipConnectionBlock, self).__init__()
483
- self.outermost = outermost
484
- if type(norm_layer) == functools.partial:
485
- use_bias = norm_layer.func == nn.InstanceNorm2d
486
- else:
487
- use_bias = norm_layer == nn.InstanceNorm2d
488
- if input_nc is None:
489
- input_nc = outer_nc
490
- downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
491
- stride=2, padding=1, bias=use_bias)
492
- downrelu = nn.LeakyReLU(0.2, True)
493
- downnorm = norm_layer(inner_nc)
494
- uprelu = nn.ReLU(True)
495
- upnorm = norm_layer(outer_nc)
496
-
497
- if outermost:
498
- upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
499
- kernel_size=4, stride=2,
500
- padding=1)
501
- down = [downconv]
502
- up = [uprelu, upconv, nn.Tanh()]
503
- model = down + [submodule] + up
504
- elif innermost:
505
- upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
506
- kernel_size=4, stride=2,
507
- padding=1, bias=use_bias)
508
- down = [downrelu, downconv]
509
- up = [uprelu, upconv, upnorm]
510
- model = down + up
511
- else:
512
- upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
513
- kernel_size=4, stride=2,
514
- padding=1, bias=use_bias)
515
- down = [downrelu, downconv, downnorm]
516
- up = [uprelu, upconv, upnorm]
517
-
518
- if use_dropout:
519
- model = down + [submodule] + up + [nn.Dropout(0.5)]
520
- else:
521
- model = down + [submodule] + up
522
-
523
- self.model = nn.Sequential(*model)
524
-
525
- def forward(self, x):
526
- if self.outermost:
527
- return self.model(x)
528
- else: # add skip connections
529
- return torch.cat([x, self.model(x)], 1)
530
-
531
-
532
- class NLayerDiscriminator(nn.Module):
533
- """Defines a PatchGAN discriminator"""
534
-
535
- def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
536
- """Construct a PatchGAN discriminator
537
-
538
- Parameters:
539
- input_nc (int) -- the number of channels in input images
540
- ndf (int) -- the number of filters in the last conv layer
541
- n_layers (int) -- the number of conv layers in the discriminator
542
- norm_layer -- normalization layer
543
- """
544
- super(NLayerDiscriminator, self).__init__()
545
- if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
546
- use_bias = norm_layer.func != nn.BatchNorm2d
547
- else:
548
- use_bias = norm_layer != nn.BatchNorm2d
549
-
550
- kw = 4
551
- padw = 1
552
- sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
553
- nf_mult = 1
554
- nf_mult_prev = 1
555
- for n in range(1, n_layers): # gradually increase the number of filters
556
- nf_mult_prev = nf_mult
557
- nf_mult = min(2 ** n, 8)
558
- sequence += [
559
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
560
- norm_layer(ndf * nf_mult),
561
- nn.LeakyReLU(0.2, True)
562
- ]
563
-
564
- nf_mult_prev = nf_mult
565
- nf_mult = min(2 ** n_layers, 8)
566
- sequence += [
567
- nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
568
- norm_layer(ndf * nf_mult),
569
- nn.LeakyReLU(0.2, True)
570
- ]
571
-
572
- sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
573
- self.model = nn.Sequential(*sequence)
574
-
575
- def forward(self, input):
576
- """Standard forward."""
577
- return self.model(input)
578
-
579
-
580
- class PixelDiscriminator(nn.Module):
581
- """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
582
-
583
- def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
584
- """Construct a 1x1 PatchGAN discriminator
585
-
586
- Parameters:
587
- input_nc (int) -- the number of channels in input images
588
- ndf (int) -- the number of filters in the last conv layer
589
- norm_layer -- normalization layer
590
- """
591
- super(PixelDiscriminator, self).__init__()
592
- if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
593
- use_bias = norm_layer.func != nn.InstanceNorm2d
594
- else:
595
- use_bias = norm_layer != nn.InstanceNorm2d
596
-
597
- self.net = [
598
- nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
599
- nn.LeakyReLU(0.2, True),
600
- nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
601
- norm_layer(ndf * 2),
602
- nn.LeakyReLU(0.2, True),
603
- nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
604
-
605
- self.net = nn.Sequential(*self.net)
606
-
607
- def forward(self, input):
608
- """Standard forward."""
609
- return self.net(input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/roadnet_model.py DELETED
@@ -1,120 +0,0 @@
1
- # Author: Yahui Liu <yahui.liu@uintn.it>
2
-
3
- import torch
4
- import numpy as np
5
- import itertools
6
- from .base_model import BaseModel
7
- import torch.nn.functional as F
8
- from .roadnet_networks import define_roadnet
9
-
10
- class RoadNetModel(BaseModel):
11
- """
12
- This class implements the RoadNet model.
13
- RoadNet paper: https://ieeexplore.ieee.org/document/8506600
14
- """
15
- @staticmethod
16
- def modify_commandline_options(parser, is_train=True):
17
- """Add new dataset-specific options, and rewrite default values for existing options."""
18
- return parser
19
-
20
- def __init__(self, opt):
21
- """Initialize the RoadNet class.
22
-
23
- Parameters:
24
- opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
25
- """
26
- BaseModel.__init__(self, opt)
27
- # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
28
- self.loss_names = ['segment', 'edge', 'centerline']
29
- # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
30
- self.visual_names = ['image', 'label_gt', 'label_pred']
31
- # specify the models you want to save to the disk.
32
- self.model_names = ['G']
33
-
34
- # define networks
35
- self.netG = define_roadnet(opt.input_nc,
36
- opt.output_nc,
37
- opt.ngf,
38
- opt.norm,
39
- opt.use_selu,
40
- opt.init_type,
41
- opt.init_gain,
42
- self.gpu_ids)
43
-
44
- if self.isTrain:
45
- # define loss functions
46
- self.weight_segment_side = [0.5, 0.75, 1.0, 0.75, 0.5, 1.0]
47
- self.weight_others_side = [0.5, 0.75, 1.0, 0.75, 1.0]
48
-
49
- # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
50
- self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, eps=1e-3, weight_decay=2e-4)
51
- #self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
52
- self.optimizers.append(self.optimizer)
53
-
54
- def _get_balanced_sigmoid_cross_entropy(self,x):
55
- count_neg = torch.sum(1. - x)
56
- count_pos = torch.sum(x)
57
- beta = count_neg / (count_neg + count_pos)
58
- pos_weight = beta / (1 - beta)
59
- cost = torch.nn.BCEWithLogitsLoss(size_average=True, reduce=True, pos_weight=pos_weight)
60
- return cost, 1-beta
61
-
62
- def set_input(self, input):
63
- """Unpack input data from the dataloader and perform necessary pre-processing steps.
64
-
65
- Parameters:
66
- input (dict): include the data itself and its metadata information.
67
- """
68
- self.image = input['image'].to(self.device)
69
- self.segment_gt = input['segment'].to(self.device)
70
- self.edge_gt = input['edge'].to(self.device)
71
- self.centerline_gt = input['centerline'].to(self.device)
72
- self.image_paths = input['A_paths']
73
- if self.isTrain:
74
- self.criterion_seg, self.beta_seg = self._get_balanced_sigmoid_cross_entropy(self.segment_gt)
75
- self.criterion_edg, self.beta_edg = self._get_balanced_sigmoid_cross_entropy(self.edge_gt)
76
- self.criterion_cnt, self.beta_cnt = self._get_balanced_sigmoid_cross_entropy(self.centerline_gt)
77
-
78
- def forward(self):
79
- """Run forward pass; called by both functions <optimize_parameters> and <test>."""
80
- self.segments, self.edges, self.centerlines = self.netG(self.image)
81
-
82
- # for visualization
83
- segment_gt_viz = (self.segment_gt-0.5)/0.5
84
- edge_gt_viz = (self.edge_gt-0.5)/0.5
85
- centerline_gt_viz = (self.centerline_gt-0.5)/0.5
86
- self.label_gt = torch.cat([centerline_gt_viz, edge_gt_viz, segment_gt_viz], dim=1)
87
-
88
- segment_fused = (torch.sigmoid(self.segments[-1])-0.5)/0.5
89
- edge_fused = (torch.sigmoid(self.edges[-1])-0.5)/0.5
90
- centerlines_fused = (torch.sigmoid(self.centerlines[-1])-0.5)/0.5
91
- self.label_pred = torch.cat([centerlines_fused, edge_fused, segment_fused], dim=1)
92
-
93
- def backward(self):
94
- """Calculate the loss"""
95
- self.loss_segment = torch.mean((torch.sigmoid(self.segments[-1])-self.segment_gt)**2) * 0.5
96
- if self.segment_gt.sum() > 0.0: # ignore blank ones
97
- for out, w in zip(self.segments, self.weight_segment_side):
98
- self.loss_segment += self.criterion_seg(out, self.segment_gt) * self.beta_seg * w
99
-
100
- self.loss_edge = torch.mean((torch.sigmoid(self.edges[-1])-self.edge_gt)**2) * 0.5
101
- if self.edge_gt.sum() > 0.0:
102
- for out, w in zip(self.edges, self.weight_others_side):
103
- self.loss_edge += self.criterion_edg(out, self.edge_gt) * self.beta_edg * w
104
-
105
- self.loss_centerline = torch.mean((torch.sigmoid(self.centerlines[-1])-self.centerline_gt)**2) * 0.5
106
- if self.centerline_gt.sum() > 0.0:
107
- for out, w in zip(self.centerlines, self.weight_others_side):
108
- self.loss_centerline += self.criterion_cnt(out, self.centerline_gt) * self.beta_cnt * w
109
-
110
- self.loss_total = self.loss_segment + self.loss_edge + self.loss_centerline
111
- self.loss_total.backward()
112
-
113
- def optimize_parameters(self, epoch=None):
114
- """Calculate losses, gradients, and update network weights; called in every training iteration"""
115
-
116
- # forward
117
- self.forward() # compute predictions.
118
- self.optimizer.zero_grad() # set G's gradients to zero
119
- self.backward() # calculate gradients for G
120
- self.optimizer.step() # update G's weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/roadnet_networks.py DELETED
@@ -1,194 +0,0 @@
1
- #! -*- coding: utf-8 -*-
2
- # Author: Yahui Liu <yahui.liu@unitn.it>
3
-
4
- """
5
- Reference:
6
-
7
- RoadNet: Learning to Comprehensively Analyze Road Networks in Complex Urban Scenes
8
- From High-Resolution Remotely Sensed Images.
9
- https://ieeexplore.ieee.org/document/8506600
10
- """
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- from .networks import get_norm_layer, init_net
16
-
17
class RoadNet(nn.Module):
    """Multi-task network for road analysis in remote sensing imagery.

    Three streams with HED-style deep supervision share one input image:
      1. road surface segmentation (5-stage VGG-like encoder),
      2. road edge detection (4-stage, half-width encoder),
      3. road centerline extraction (4-stage, half-width encoder).
    The fused segmentation map is concatenated to the input image before
    it is fed to the edge and centerline streams.

    Reference:
        RoadNet: Learning to Comprehensively Analyze Road Networks in
        Complex Urban Scenes From High-Resolution Remotely Sensed Images.
        https://ieeexplore.ieee.org/document/8506600
    """

    def __init__(self, in_nc, out_nc, ngf, norm='batch', use_selu=1):
        """Build the three prediction streams.

        Parameters:
            in_nc (int)    -- number of input image channels
            out_nc (int)   -- number of output channels per prediction map
            ngf (int)      -- base filter count of the segmentation stream
                              (edge/centerline streams use ngf // 2)
            norm (str)     -- normalization layer type for conv blocks
                              (only used when use_selu is falsy)
            use_selu (int) -- if truthy, conv blocks use SELU activations
                              instead of norm_layer + ReLU
        """
        super(RoadNet, self).__init__()
        norm_layer = get_norm_layer(norm_type=norm)

        # ------------ road surface segmentation ------------ #
        self.segment_conv1 = nn.Sequential(*self._conv_block(in_nc, ngf, norm_layer, use_selu, num_block=2))
        self.side_segment_conv1 = nn.Conv2d(ngf, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv2 = nn.Sequential(*self._conv_block(ngf, ngf*2, norm_layer, use_selu, num_block=2))
        self.side_segment_conv2 = nn.Conv2d(ngf*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv3 = nn.Sequential(*self._conv_block(ngf*2, ngf*4, norm_layer, use_selu, num_block=3))
        self.side_segment_conv3 = nn.Conv2d(ngf*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv4 = nn.Sequential(*self._conv_block(ngf*4, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv4 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv5 = nn.Sequential(*self._conv_block(ngf*8, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv5 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        # 1x1 conv that fuses the five upsampled side outputs.
        self.fuse_segment_conv = nn.Conv2d(out_nc*5, out_nc, kernel_size=1, stride=1, bias=False)

        ngf2 = ngf // 2
        # ------------ road edge detection ------------ #
        # Input is the image concatenated with the fused segmentation map.
        self.edge_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_edge_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_edge_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_edge_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        # ------------ road centerline extraction ------------ #
        # Same topology as the edge stream, separate weights.
        self.centerline_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_centerline_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.maxpool = nn.MaxPool2d(2, stride=2)

    def _conv_block(self, in_nc, out_nc, norm_layer, use_selu, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=True):
        """Return a list of ``num_block`` conv+activation layers.

        With ``use_selu`` truthy each conv is followed by SELU; otherwise
        by ``norm_layer`` + ReLU. The first conv maps in_nc -> out_nc,
        the rest keep out_nc channels.
        """
        conv = []
        for i in range(num_block):
            cur_in_nc = in_nc if i == 0 else out_nc
            conv += [nn.Conv2d(cur_in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=bias)]
            if use_selu:
                # BUGFIX: was nn.SeLU, which does not exist in torch.nn
                # (AttributeError at construction); the correct class is SELU.
                conv += [nn.SELU(True)]
            else:
                conv += [norm_layer(out_nc), nn.ReLU(True)]
        return conv

    def _segment_forward(self, x):
        """Predict road surface segmentation.

        :param x: image tensor, [N, C, H, W]
        :return: list of five side outputs (upsampled to input size) plus
                 the fused output, each [N, out_nc, H, W]; all are logits.
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.segment_conv1(x)
        conv2 = self.segment_conv2(self.maxpool(conv1))
        conv3 = self.segment_conv3(self.maxpool(conv2))
        conv4 = self.segment_conv4(self.maxpool(conv3))
        conv5 = self.segment_conv5(self.maxpool(conv4))
        # side output features
        side_output1 = self.side_segment_conv1(conv1)
        side_output2 = self.side_segment_conv2(conv2)
        side_output3 = self.side_segment_conv3(conv3)
        side_output4 = self.side_segment_conv4(conv4)
        side_output5 = self.side_segment_conv5(conv5)
        # upsample every side output back to the input resolution
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        side_output5 = F.interpolate(side_output5, size=(h, w), mode='bilinear', align_corners=True)

        fused = self.fuse_segment_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4,
            side_output5], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, side_output5, fused]

    def _edge_forward(self, x):
        """Predict road edges.

        :param x: image concatenated with the fused segmentation map,
                  [N, C+out_nc, H, W]
        :return: list of four upsampled side outputs plus the fused
                 output, each [N, out_nc, H, W]; all are logits.
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.edge_conv1(x)
        conv2 = self.edge_conv2(self.maxpool(conv1))
        conv3 = self.edge_conv3(self.maxpool(conv2))
        conv4 = self.edge_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_edge_conv1(conv1)
        side_output2 = self.side_edge_conv2(conv2)
        side_output3 = self.side_edge_conv3(conv3)
        side_output4 = self.side_edge_conv4(conv4)
        # upsample every side output back to the input resolution
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        fused = self.fuse_edge_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def _centerline_forward(self, x):
        """Predict road centerlines.

        :param x: image concatenated with the fused segmentation map,
                  [N, C+out_nc, H, W]
        :return: list of four upsampled side outputs plus the fused
                 output, each [N, out_nc, H, W]; all are logits.
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.centerline_conv1(x)
        conv2 = self.centerline_conv2(self.maxpool(conv1))
        conv3 = self.centerline_conv3(self.maxpool(conv2))
        conv4 = self.centerline_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_centerline_conv1(conv1)
        side_output2 = self.side_centerline_conv2(conv2)
        side_output3 = self.side_centerline_conv3(conv3)
        side_output4 = self.side_centerline_conv4(conv4)
        # upsample every side output back to the input resolution
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True)
        fused = self.fuse_centerline_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def forward(self, x):
        """Run all three streams.

        :param x: image tensor, [N, C, H, W]
        :return: (segments, edges, centerlines) — the lists produced by
                 the three ``_*_forward`` helpers; each list ends with
                 the fused prediction.
        """
        segments = self._segment_forward(x)

        # Condition the edge/centerline streams on the fused segmentation.
        x_ = torch.cat([x, segments[-1]], dim=1)
        edges = self._edge_forward(x_)
        centerlines = self._centerline_forward(x_)
        return segments, edges, centerlines
184
-
185
def define_roadnet(in_nc, out_nc, ngf, norm='batch', use_selu=1,
                   init_type='xavier', init_gain=0.02, gpu_ids=[]):
    """Construct a RoadNet and initialize its weights.

    Parameters:
        in_nc (int)      -- number of input image channels
        out_nc (int)     -- number of output channels per prediction map
        ngf (int)        -- base filter count of the segmentation stream
        norm (str)       -- normalization layer type
        use_selu (int)   -- use SELU activations instead of norm + ReLU
        init_type (str)  -- weight initialization scheme
        init_gain (float)-- scaling factor for the initialization
        gpu_ids (list)   -- GPU ids passed through to init_net

    Returns the initialized network.
    """
    network = RoadNet(in_nc, out_nc, ngf, norm, use_selu)
    return init_net(network, init_type, init_gain, gpu_ids)