Spaces:
Running
Running
Upload 5 files
Browse files- models/base_model.py +229 -0
- models/deepcrack_networks.py +110 -0
- models/networks.py +609 -0
- models/roadnet_model.py +120 -0
- models/roadnet_networks.py +194 -0
models/base_model.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from . import networks
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseModel(ABC):
|
| 9 |
+
"""This class is an abstract base class (ABC) for models.
|
| 10 |
+
To create a subclass, you need to implement the following five functions:
|
| 11 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
| 12 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
| 13 |
+
-- <forward>: produce intermediate results.
|
| 14 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
| 15 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, opt):
|
| 19 |
+
"""Initialize the BaseModel class.
|
| 20 |
+
|
| 21 |
+
Parameters:
|
| 22 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
| 23 |
+
|
| 24 |
+
When creating your custom class, you need to implement your own initialization.
|
| 25 |
+
In this fucntion, you should first call <BaseModel.__init__(self, opt)>
|
| 26 |
+
Then, you need to define four lists:
|
| 27 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
| 28 |
+
-- self.model_names (str list): specify the images that you want to display and save.
|
| 29 |
+
-- self.visual_names (str list): define networks used in our training.
|
| 30 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
| 31 |
+
"""
|
| 32 |
+
self.opt = opt
|
| 33 |
+
self.gpu_ids = opt.gpu_ids
|
| 34 |
+
self.isTrain = opt.isTrain
|
| 35 |
+
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
| 36 |
+
if hasattr(opt, 'checkpoints_dir'):
|
| 37 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
| 38 |
+
if not hasattr(opt, 'preprocess') or opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
| 39 |
+
torch.backends.cudnn.benchmark = True
|
| 40 |
+
self.loss_names = []
|
| 41 |
+
self.model_names = []
|
| 42 |
+
self.visual_names = []
|
| 43 |
+
self.optimizers = []
|
| 44 |
+
self.image_paths = []
|
| 45 |
+
self.metric = 0 # used for learning rate policy 'plateau'
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def modify_commandline_options(parser, is_train):
|
| 49 |
+
"""Add new model-specific options, and rewrite default values for existing options.
|
| 50 |
+
|
| 51 |
+
Parameters:
|
| 52 |
+
parser -- original option parser
|
| 53 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
the modified parser.
|
| 57 |
+
"""
|
| 58 |
+
return parser
|
| 59 |
+
|
| 60 |
+
@abstractmethod
|
| 61 |
+
def set_input(self, input):
|
| 62 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
| 63 |
+
|
| 64 |
+
Parameters:
|
| 65 |
+
input (dict): includes the data itself and its metadata information.
|
| 66 |
+
"""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
@abstractmethod
|
| 70 |
+
def forward(self):
|
| 71 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def optimize_parameters(self):
|
| 76 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def setup(self, opt):
|
| 80 |
+
"""Load and print networks; create schedulers
|
| 81 |
+
|
| 82 |
+
Parameters:
|
| 83 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
| 84 |
+
"""
|
| 85 |
+
if self.isTrain:
|
| 86 |
+
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
| 87 |
+
if not self.isTrain or opt.continue_train:
|
| 88 |
+
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
|
| 89 |
+
self.load_networks(load_suffix)
|
| 90 |
+
self.print_networks(opt.verbose)
|
| 91 |
+
|
| 92 |
+
def eval(self):
|
| 93 |
+
"""Make models eval mode during test time"""
|
| 94 |
+
for name in self.model_names:
|
| 95 |
+
if isinstance(name, str):
|
| 96 |
+
net = getattr(self, 'net' + name)
|
| 97 |
+
net.eval()
|
| 98 |
+
|
| 99 |
+
def test(self):
|
| 100 |
+
"""Forward function used in test time.
|
| 101 |
+
|
| 102 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
| 103 |
+
It also calls <compute_visuals> to produce additional visualization results
|
| 104 |
+
"""
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
self.forward()
|
| 107 |
+
self.compute_visuals()
|
| 108 |
+
|
| 109 |
+
def compute_visuals(self):
|
| 110 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
def get_image_paths(self):
|
| 114 |
+
""" Return image paths that are used to load current data"""
|
| 115 |
+
return self.image_paths
|
| 116 |
+
|
| 117 |
+
def update_learning_rate(self):
|
| 118 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
| 119 |
+
for scheduler in self.schedulers:
|
| 120 |
+
if self.opt.lr_policy == 'plateau':
|
| 121 |
+
scheduler.step(self.metric)
|
| 122 |
+
else:
|
| 123 |
+
scheduler.step()
|
| 124 |
+
lr = self.optimizers[0].param_groups[0]['lr']
|
| 125 |
+
print('learning rate = %.7f' % lr)
|
| 126 |
+
|
| 127 |
+
def get_current_visuals(self):
|
| 128 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
| 129 |
+
visual_ret = OrderedDict()
|
| 130 |
+
for name in self.visual_names:
|
| 131 |
+
if isinstance(name, str):
|
| 132 |
+
visual_ret[name] = getattr(self, name)
|
| 133 |
+
return visual_ret
|
| 134 |
+
|
| 135 |
+
def get_current_losses(self):
|
| 136 |
+
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
| 137 |
+
errors_ret = OrderedDict()
|
| 138 |
+
for name in self.loss_names:
|
| 139 |
+
if isinstance(name, str):
|
| 140 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
| 141 |
+
return errors_ret
|
| 142 |
+
|
| 143 |
+
def save_networks(self, epoch):
|
| 144 |
+
"""Save all the networks to the disk.
|
| 145 |
+
|
| 146 |
+
Parameters:
|
| 147 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
| 148 |
+
"""
|
| 149 |
+
for name in self.model_names:
|
| 150 |
+
if isinstance(name, str):
|
| 151 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
| 152 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
| 153 |
+
net = getattr(self, 'net' + name)
|
| 154 |
+
|
| 155 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
| 156 |
+
torch.save(net.module.cpu().state_dict(), save_path)
|
| 157 |
+
net.cuda(self.gpu_ids[0])
|
| 158 |
+
else:
|
| 159 |
+
torch.save(net.cpu().state_dict(), save_path)
|
| 160 |
+
|
| 161 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
| 162 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
| 163 |
+
key = keys[i]
|
| 164 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
| 165 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
| 166 |
+
(key == 'running_mean' or key == 'running_var'):
|
| 167 |
+
if getattr(module, key) is None:
|
| 168 |
+
state_dict.pop('.'.join(keys))
|
| 169 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
| 170 |
+
(key == 'num_batches_tracked'):
|
| 171 |
+
state_dict.pop('.'.join(keys))
|
| 172 |
+
else:
|
| 173 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
| 174 |
+
|
| 175 |
+
def load_networks(self, epoch):
|
| 176 |
+
"""Load all the networks from the disk.
|
| 177 |
+
|
| 178 |
+
Parameters:
|
| 179 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
| 180 |
+
"""
|
| 181 |
+
for name in self.model_names:
|
| 182 |
+
if isinstance(name, str):
|
| 183 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
| 184 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
| 185 |
+
net = getattr(self, 'net' + name)
|
| 186 |
+
if isinstance(net, torch.nn.DataParallel):
|
| 187 |
+
net = net.module
|
| 188 |
+
print('loading the model from %s' % load_path)
|
| 189 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 190 |
+
# GitHub source), you can remove str() on self.device
|
| 191 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
| 192 |
+
if hasattr(state_dict, '_metadata'):
|
| 193 |
+
del state_dict._metadata
|
| 194 |
+
|
| 195 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
| 196 |
+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
| 197 |
+
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
| 198 |
+
net.load_state_dict(state_dict, strict=False)
|
| 199 |
+
|
| 200 |
+
def print_networks(self, verbose):
|
| 201 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
| 202 |
+
|
| 203 |
+
Parameters:
|
| 204 |
+
verbose (bool) -- if verbose: print the network architecture
|
| 205 |
+
"""
|
| 206 |
+
print('---------- Networks initialized -------------')
|
| 207 |
+
for name in self.model_names:
|
| 208 |
+
if isinstance(name, str):
|
| 209 |
+
net = getattr(self, 'net' + name)
|
| 210 |
+
num_params = 0
|
| 211 |
+
for param in net.parameters():
|
| 212 |
+
num_params += param.numel()
|
| 213 |
+
if verbose:
|
| 214 |
+
print(net)
|
| 215 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
| 216 |
+
print('-----------------------------------------------')
|
| 217 |
+
|
| 218 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
| 219 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
| 220 |
+
Parameters:
|
| 221 |
+
nets (network list) -- a list of networks
|
| 222 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
| 223 |
+
"""
|
| 224 |
+
if not isinstance(nets, list):
|
| 225 |
+
nets = [nets]
|
| 226 |
+
for net in nets:
|
| 227 |
+
if net is not None:
|
| 228 |
+
for param in net.parameters():
|
| 229 |
+
param.requires_grad = requires_grad
|
models/deepcrack_networks.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! -*- coding: utf-8 -*-
|
| 2 |
+
# Author: Yahui Liu <yahui.liu@unitn.it>
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Reference:
|
| 6 |
+
|
| 7 |
+
DeepCrack: A deep hierarchical feature learning architecture for crack segmentation.
|
| 8 |
+
https://www.sciencedirect.com/science/article/pii/S0925231219300566
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from .networks import get_norm_layer, init_net
|
| 15 |
+
|
| 16 |
+
class DeepCrackNet(nn.Module):
    """DeepCrack: hierarchical multi-scale crack segmentation network.

    A VGG-style encoder with five conv stages. Each stage feeds a 1x1 "side"
    conv that produces a per-scale prediction; side outputs from the coarser
    stages are bilinearly upsampled back to the input resolution and all five
    are fused by a final 1x1 conv.

    Reference:
        DeepCrack: A deep hierarchical feature learning architecture for crack
        segmentation. https://www.sciencedirect.com/science/article/pii/S0925231219300566
    """

    def __init__(self, in_nc, num_classes, ngf, norm='batch'):
        super(DeepCrackNet, self).__init__()

        norm_layer = get_norm_layer(norm_type=norm)
        # (stage input channels, stage output channels, conv blocks per stage);
        # widths follow VGG: ngf, 2*ngf, 4*ngf, 8*ngf, 8*ngf.
        stage_specs = [(in_nc, ngf, 2),
                       (ngf, ngf * 2, 2),
                       (ngf * 2, ngf * 4, 3),
                       (ngf * 4, ngf * 8, 3),
                       (ngf * 8, ngf * 8, 3)]
        for idx, (nc_in, nc_out, blocks) in enumerate(stage_specs, start=1):
            setattr(self, 'conv%d' % idx,
                    nn.Sequential(*self._conv_block(nc_in, nc_out, norm_layer, num_block=blocks)))
            setattr(self, 'side_conv%d' % idx,
                    nn.Conv2d(nc_out, num_classes, kernel_size=1, stride=1, bias=False))

        self.fuse_conv = nn.Conv2d(num_classes * 5, num_classes, kernel_size=1, stride=1, bias=False)
        self.maxpool = nn.MaxPool2d(2, stride=2)

    def _conv_block(self, in_nc, out_nc, norm_layer, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=False):
        """Build `num_block` conv -> norm -> ReLU triples as a flat layer list."""
        layers = []
        for i in range(num_block):
            src_nc = in_nc if i == 0 else out_nc
            layers.extend([
                nn.Conv2d(src_nc, out_nc, kernel_size=kernel_size, stride=stride,
                          padding=padding, bias=bias),
                norm_layer(out_nc),
                nn.ReLU(True),
            ])
        return layers

    def forward(self, x):
        """Return the five per-scale side outputs plus the fused prediction,
        all at the input's spatial resolution."""
        h, w = x.size()[2:]
        side_outputs = []
        feat = x
        for idx in range(1, 6):
            if idx > 1:
                feat = self.maxpool(feat)
            feat = getattr(self, 'conv%d' % idx)(feat)
            side = getattr(self, 'side_conv%d' % idx)(feat)
            if idx > 1:
                # Bring the coarser prediction back to the input resolution.
                side = F.interpolate(side, size=(h, w), mode='bilinear', align_corners=True)
            side_outputs.append(side)

        fused = self.fuse_conv(torch.cat(side_outputs, dim=1))
        return tuple(side_outputs) + (fused,)
|
| 81 |
+
|
| 82 |
+
def define_deepcrack(in_nc,
                     num_classes,
                     ngf,
                     norm='batch',
                     init_type='xavier',
                     init_gain=0.02,
                     gpu_ids=[]):
    """Build a DeepCrackNet and run weight initialization / device placement.

    Parameters:
        in_nc (int)        -- number of channels in input images
        num_classes (int)  -- number of output segmentation classes
        ngf (int)          -- base number of encoder filters
        norm (str)         -- normalization layer type: batch | instance | none
        init_type (str)    -- weight initialization method
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal init
        gpu_ids (int list) -- GPUs to run on, e.g. [0, 1, 2]

    Returns the initialized network (see <init_net>).
    """
    return init_net(DeepCrackNet(in_nc, num_classes, ngf, norm),
                    init_type, init_gain, gpu_ids)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class BinaryFocalLoss(nn.Module):
    """Binary focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Down-weights well-classified examples by the factor (1 - p_t)^gamma on top
    of a per-element BCE-with-logits loss, then reduces by mean or sum.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, size_average=True):
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # NOTE(review): `logits` is stored but never consulted — the loss always
        # uses BCEWithLogitsLoss, i.e. inputs are assumed to be raw scores.
        # Confirm against callers before honoring this flag.
        self.logits = logits
        self.size_average = size_average
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        per_elem_bce = self.criterion(inputs, targets)
        # exp(-BCE) recovers the model's probability for the true class.
        prob_true = torch.exp(-per_elem_bce)
        focal = self.alpha * (1 - prob_true) ** self.gamma * per_elem_bce
        return focal.mean() if self.size_average else focal.sum()
|
models/networks.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import init
|
| 4 |
+
import functools
|
| 5 |
+
from torch.optim import lr_scheduler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
###############################################################################
|
| 9 |
+
# Helper Functions
|
| 10 |
+
###############################################################################
|
| 11 |
+
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running
    statistics (mean/stddev). For InstanceNorm, we use neither. 'none'
    yields None.

    Raises:
        NotImplementedError -- for an unknown norm_type.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy:
                              linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch
    schedulers. See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not a supported policy.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # 1.0 during the first opt.niter epochs, then linear decay to 0
            # over the following opt.niter_decay epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # BUG FIX: the original *returned* a NotImplementedError instance (with the
        # policy name passed as a second argument instead of %-formatted), so an
        # unknown policy silently handed callers an exception object rather than
        # failing. Raise it properly instead.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """
    def _apply_init(m):
        cls_name = m.__class__.__name__
        is_conv_or_linear = 'Conv' in cls_name or 'Linear' in cls_name
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm's weight is a vector, not a matrix; only the normal
            # distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_apply_init)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support);
    2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network (wrapped in DataParallel when GPUs are used).
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        # multi-GPU wrapper; note this changes the object callers receive
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, init_gain=init_gain)
    return net
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a generator.

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- if use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator.

    Two families of generators are supported:
        * U-Net: [unet_128] (for 128x128 inputs) and [unet_256] (for 256x256 inputs);
          the original U-Net paper: https://arxiv.org/abs/1505.04597
        * ResNet-based: [resnet_6blocks] / [resnet_9blocks], several Resnet blocks
          between a few downsampling/upsampling operations, adapted from Justin
          Johnson's neural style transfer project
          (https://github.com/jcjohnson/fast-neural-style).

    The generator is initialized by <init_net> and uses ReLU for non-linearity.

    Raises:
        NotImplementedError -- for an unrecognized netG name.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    # Lazy builders so only the requested architecture is constructed.
    builders = {
        'resnet_9blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9),
        'resnet_6blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6),
        'unet_128': lambda: UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
        'unet_256': lambda: UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
    }
    if netG not in builders:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(builders[netG](), init_type, init_gain, gpu_ids)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current implementation provides three types of discriminators:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70x70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        # BUG FIX: the message previously interpolated `net` (always None on this
        # path) instead of the requested architecture name `netD`.
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
##############################################################################
|
| 200 |
+
# Classes
|
| 201 |
+
##############################################################################
|
| 202 |
+
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (bool) - - label for a real image
            target_fake_label (bool) - - label of a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # buffers follow the module across devices but are not trained
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # WGAN-GP scores are used directly (signed mean); no fixed-target loss
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode == 'wgangp':
            # critic score: maximize for real, minimize for fake
            return -prediction.mean() if target_is_real else prediction.mean()
        target = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        device (str)                -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)            -- the constant used in formula ( | |gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss

    Returns the gradient penalty loss and the flattened per-sample gradients,
    or (0.0, None) when lambda_gp is disabled.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    # choose the points where the critic's gradient norm is penalized:
    # real images, fake images, or a random per-sample interpolation of both
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        alpha = torch.rand(batch, 1)
        alpha = alpha.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        alpha = alpha.to(device)
        interpolatesv = alpha * real_data + (1 - alpha) * fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    interpolatesv.requires_grad_(True)
    disc_interpolates = netD(interpolatesv)
    grad_outputs = torch.ones(disc_interpolates.size()).to(device)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                    grad_outputs=grad_outputs,
                                    create_graph=True, retain_graph=True, only_inputs=True)
    flat_grads = gradients[0].view(real_data.size(0), -1)  # one row per sample
    # eps keeps the norm differentiable when gradients vanish
    gradient_penalty = (((flat_grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, flat_grads
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        # conv bias is only useful with InstanceNorm (no learned affine shift)
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # strided convs halve resolution, double channels
            mult = 2 ** i
            layers += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(ngf * mult * 2),
                       nn.ReLU(True)]

        mult = 2 ** n_downsampling
        # residual trunk at the lowest resolution
        layers += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                               use_dropout=use_dropout, use_bias=use_bias)
                   for _ in range(n_blocks)]

        for i in range(n_downsampling):  # transposed convs restore the input resolution
            mult = 2 ** (n_downsampling - i)
            layers += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1,
                                          bias=use_bias),
                       norm_layer(int(ngf * mult / 2)),
                       nn.ReLU(True)]
        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        def padding():
            # explicit pad layer for reflect/replicate; implicit conv padding for zero
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        pad_layers, p = padding()
        block = list(pad_layers)
        block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            block += [nn.Dropout(0.5)]

        pad_layers, p = padding()  # fresh pad instances for the second conv
        block += list(pad_layers)
        block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        return x + self.conv_block(x)  # add skip connections
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                               image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # innermost bottleneck first, then wrap it outward layer by layer
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):  # intermediate layers keep ngf * 8 filters
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 down to ngf
        for outer in (ngf * 4, ngf * 2, ngf):
            block = UnetSkipConnectionBlock(outer, outer * 2, input_nc=None, submodule=block,
                                            norm_layer=norm_layer)
        # outermost layer maps input_nc -> output_nc
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # conv bias only matters with InstanceNorm (no learned affine shift)
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)

        if outermost:
            # final layer: no norm, Tanh output; input is submodule output + skip (2x channels)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1)
            model = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # bottleneck: no submodule, so upconv input is just inner_nc channels
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, uprelu, upconv, norm_layer(outer_nc)]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, norm_layer(inner_nc),
                     submodule,
                     uprelu, upconv, norm_layer(outer_nc)]
            if use_dropout:
                model = model + [nn.Dropout(0.5)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        out = self.model(x)
        # outermost block returns the plain output; inner blocks concatenate the skip
        return out if self.outermost else torch.cat([x, out], 1)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw, padw = 4, 1
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]

        mult_prev, mult = 1, 1
        for n in range(1, n_layers):  # strided convs double the channel count (capped at 8x)
            mult_prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * mult_prev, ndf * mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        mult_prev, mult = mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * mult_prev, ndf * mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output one-channel prediction map
        layers += [nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters.
        # BUG FIX: the check previously compared against nn.InstanceNorm2d, which is
        # inverted relative to the comment's intent and inconsistent with
        # NLayerDiscriminator; compare against nn.BatchNorm2d as there.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
|
models/roadnet_model.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Author: Yahui Liu <yahui.liu@uintn.it>
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import itertools
|
| 6 |
+
from .base_model import BaseModel
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from .roadnet_networks import define_roadnet
|
| 9 |
+
|
| 10 |
+
class RoadNetModel(BaseModel):
|
| 11 |
+
"""
|
| 12 |
+
This class implements the RoadNet model.
|
| 13 |
+
RoadNet paper: https://ieeexplore.ieee.org/document/8506600
|
| 14 |
+
"""
|
| 15 |
+
@staticmethod
|
| 16 |
+
def modify_commandline_options(parser, is_train=True):
|
| 17 |
+
"""Add new dataset-specific options, and rewrite default values for existing options."""
|
| 18 |
+
return parser
|
| 19 |
+
|
| 20 |
+
def __init__(self, opt):
|
| 21 |
+
"""Initialize the RoadNet class.
|
| 22 |
+
|
| 23 |
+
Parameters:
|
| 24 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
| 25 |
+
"""
|
| 26 |
+
BaseModel.__init__(self, opt)
|
| 27 |
+
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
| 28 |
+
self.loss_names = ['segment', 'edge', 'centerline']
|
| 29 |
+
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
| 30 |
+
self.visual_names = ['image', 'label_gt', 'label_pred']
|
| 31 |
+
# specify the models you want to save to the disk.
|
| 32 |
+
self.model_names = ['G']
|
| 33 |
+
|
| 34 |
+
# define networks
|
| 35 |
+
self.netG = define_roadnet(opt.input_nc,
|
| 36 |
+
opt.output_nc,
|
| 37 |
+
opt.ngf,
|
| 38 |
+
opt.norm,
|
| 39 |
+
opt.use_selu,
|
| 40 |
+
opt.init_type,
|
| 41 |
+
opt.init_gain,
|
| 42 |
+
self.gpu_ids)
|
| 43 |
+
|
| 44 |
+
if self.isTrain:
|
| 45 |
+
# define loss functions
|
| 46 |
+
self.weight_segment_side = [0.5, 0.75, 1.0, 0.75, 0.5, 1.0]
|
| 47 |
+
self.weight_others_side = [0.5, 0.75, 1.0, 0.75, 1.0]
|
| 48 |
+
|
| 49 |
+
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
|
| 50 |
+
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, eps=1e-3, weight_decay=2e-4)
|
| 51 |
+
#self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
|
| 52 |
+
self.optimizers.append(self.optimizer)
|
| 53 |
+
|
| 54 |
+
def _get_balanced_sigmoid_cross_entropy(self,x):
|
| 55 |
+
count_neg = torch.sum(1. - x)
|
| 56 |
+
count_pos = torch.sum(x)
|
| 57 |
+
beta = count_neg / (count_neg + count_pos)
|
| 58 |
+
pos_weight = beta / (1 - beta)
|
| 59 |
+
cost = torch.nn.BCEWithLogitsLoss(size_average=True, reduce=True, pos_weight=pos_weight)
|
| 60 |
+
return cost, 1-beta
|
| 61 |
+
|
| 62 |
+
def set_input(self, input):
|
| 63 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
| 64 |
+
|
| 65 |
+
Parameters:
|
| 66 |
+
input (dict): include the data itself and its metadata information.
|
| 67 |
+
"""
|
| 68 |
+
self.image = input['image'].to(self.device)
|
| 69 |
+
self.segment_gt = input['segment'].to(self.device)
|
| 70 |
+
self.edge_gt = input['edge'].to(self.device)
|
| 71 |
+
self.centerline_gt = input['centerline'].to(self.device)
|
| 72 |
+
self.image_paths = input['A_paths']
|
| 73 |
+
if self.isTrain:
|
| 74 |
+
self.criterion_seg, self.beta_seg = self._get_balanced_sigmoid_cross_entropy(self.segment_gt)
|
| 75 |
+
self.criterion_edg, self.beta_edg = self._get_balanced_sigmoid_cross_entropy(self.edge_gt)
|
| 76 |
+
self.criterion_cnt, self.beta_cnt = self._get_balanced_sigmoid_cross_entropy(self.centerline_gt)
|
| 77 |
+
|
| 78 |
+
def forward(self):
|
| 79 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
| 80 |
+
self.segments, self.edges, self.centerlines = self.netG(self.image)
|
| 81 |
+
|
| 82 |
+
# for visualization
|
| 83 |
+
segment_gt_viz = (self.segment_gt-0.5)/0.5
|
| 84 |
+
edge_gt_viz = (self.edge_gt-0.5)/0.5
|
| 85 |
+
centerline_gt_viz = (self.centerline_gt-0.5)/0.5
|
| 86 |
+
self.label_gt = torch.cat([centerline_gt_viz, edge_gt_viz, segment_gt_viz], dim=1)
|
| 87 |
+
|
| 88 |
+
segment_fused = (torch.sigmoid(self.segments[-1])-0.5)/0.5
|
| 89 |
+
edge_fused = (torch.sigmoid(self.edges[-1])-0.5)/0.5
|
| 90 |
+
centerlines_fused = (torch.sigmoid(self.centerlines[-1])-0.5)/0.5
|
| 91 |
+
self.label_pred = torch.cat([centerlines_fused, edge_fused, segment_fused], dim=1)
|
| 92 |
+
|
| 93 |
+
def backward(self):
|
| 94 |
+
"""Calculate the loss"""
|
| 95 |
+
self.loss_segment = torch.mean((torch.sigmoid(self.segments[-1])-self.segment_gt)**2) * 0.5
|
| 96 |
+
if self.segment_gt.sum() > 0.0: # ignore blank ones
|
| 97 |
+
for out, w in zip(self.segments, self.weight_segment_side):
|
| 98 |
+
self.loss_segment += self.criterion_seg(out, self.segment_gt) * self.beta_seg * w
|
| 99 |
+
|
| 100 |
+
self.loss_edge = torch.mean((torch.sigmoid(self.edges[-1])-self.edge_gt)**2) * 0.5
|
| 101 |
+
if self.edge_gt.sum() > 0.0:
|
| 102 |
+
for out, w in zip(self.edges, self.weight_others_side):
|
| 103 |
+
self.loss_edge += self.criterion_edg(out, self.edge_gt) * self.beta_edg * w
|
| 104 |
+
|
| 105 |
+
self.loss_centerline = torch.mean((torch.sigmoid(self.centerlines[-1])-self.centerline_gt)**2) * 0.5
|
| 106 |
+
if self.centerline_gt.sum() > 0.0:
|
| 107 |
+
for out, w in zip(self.centerlines, self.weight_others_side):
|
| 108 |
+
self.loss_centerline += self.criterion_cnt(out, self.centerline_gt) * self.beta_cnt * w
|
| 109 |
+
|
| 110 |
+
self.loss_total = self.loss_segment + self.loss_edge + self.loss_centerline
|
| 111 |
+
self.loss_total.backward()
|
| 112 |
+
|
| 113 |
+
def optimize_parameters(self, epoch=None):
|
| 114 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
| 115 |
+
|
| 116 |
+
# forward
|
| 117 |
+
self.forward() # compute predictions.
|
| 118 |
+
self.optimizer.zero_grad() # set G's gradients to zero
|
| 119 |
+
self.backward() # calculate gradients for G
|
| 120 |
+
self.optimizer.step() # update G's weights
|
models/roadnet_networks.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! -*- coding: utf-8 -*-
|
| 2 |
+
# Author: Yahui Liu <yahui.liu@unitn.it>
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Reference:
|
| 6 |
+
|
| 7 |
+
RoadNet: Learning to Comprehensively Analyze Road Networks in Complex Urban Scenes
|
| 8 |
+
From High-Resolution Remotely Sensed Images.
|
| 9 |
+
https://ieeexplore.ieee.org/document/8506600
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from .networks import get_norm_layer, init_net
|
| 16 |
+
|
| 17 |
+
class RoadNet(nn.Module):
    """Multi-task network for road analysis from remotely sensed images.

    Reference:
        RoadNet: Learning to Comprehensively Analyze Road Networks in Complex
        Urban Scenes From High-Resolution Remotely Sensed Images.
        https://ieeexplore.ieee.org/document/8506600

    Three VGG-style branches with HED-like deep supervision (1x1 side outputs
    upsampled to the input resolution and fused by a 1x1 conv):
      * road surface segmentation — 5 stages at full ``ngf`` width;
      * road edge detection — 4 stages at half width;
      * road centerline extraction — 4 stages at half width.
    The fused segmentation map is concatenated to the input image before being
    fed to the edge and centerline branches (cascaded multi-task design).
    """

    def __init__(self, in_nc, out_nc, ngf, norm='batch', use_selu=1):
        """
        Parameters:
            in_nc (int)    -- number of input image channels
            out_nc (int)   -- number of channels of every side/fused output
            ngf (int)      -- base width of the segmentation branch
            norm (str)     -- normalization type passed to get_norm_layer;
                              only used when use_selu is falsy
            use_selu (int) -- if truthy, SELU activations (self-normalizing,
                              no norm layer); otherwise norm_layer + ReLU
        """
        super(RoadNet, self).__init__()
        norm_layer = get_norm_layer(norm_type=norm)

        #------------road surface segmentation------------#
        self.segment_conv1 = nn.Sequential(*self._conv_block(in_nc, ngf, norm_layer, use_selu, num_block=2))
        self.side_segment_conv1 = nn.Conv2d(ngf, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv2 = nn.Sequential(*self._conv_block(ngf, ngf*2, norm_layer, use_selu, num_block=2))
        self.side_segment_conv2 = nn.Conv2d(ngf*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv3 = nn.Sequential(*self._conv_block(ngf*2, ngf*4, norm_layer, use_selu, num_block=3))
        self.side_segment_conv3 = nn.Conv2d(ngf*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv4 = nn.Sequential(*self._conv_block(ngf*4, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv4 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.segment_conv5 = nn.Sequential(*self._conv_block(ngf*8, ngf*8, norm_layer, use_selu, num_block=3))
        self.side_segment_conv5 = nn.Conv2d(ngf*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_segment_conv = nn.Conv2d(out_nc*5, out_nc, kernel_size=1, stride=1, bias=False)

        ngf2 = ngf//2  # edge/centerline branches run at half width
        #------------road edge detection------------#
        # input is image + fused segmentation, hence in_nc + out_nc channels
        self.edge_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_edge_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_edge_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.edge_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_edge_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_edge_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        #------------road centerline extraction------------#
        self.centerline_conv1 = nn.Sequential(*self._conv_block(in_nc+out_nc, ngf2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv1 = nn.Conv2d(ngf2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv2 = nn.Sequential(*self._conv_block(ngf2, ngf2*2, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv2 = nn.Conv2d(ngf2*2, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv3 = nn.Sequential(*self._conv_block(ngf2*2, ngf2*4, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv3 = nn.Conv2d(ngf2*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.centerline_conv4 = nn.Sequential(*self._conv_block(ngf2*4, ngf2*8, norm_layer, use_selu, num_block=2))
        self.side_centerline_conv4 = nn.Conv2d(ngf2*8, out_nc, kernel_size=1, stride=1, bias=False)

        self.fuse_centerline_conv = nn.Conv2d(out_nc*4, out_nc, kernel_size=1, stride=1, bias=False)

        self.maxpool = nn.MaxPool2d(2, stride=2)

        #self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        #self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        #self.up8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        #self.up16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)

    def _conv_block(self, in_nc, out_nc, norm_layer, use_selu, num_block=2, kernel_size=3,
                    stride=1, padding=1, bias=True):
        """Return a list of ``num_block`` conv+activation layers (VGG stage).

        The first conv maps in_nc -> out_nc; the rest keep out_nc channels.
        """
        conv = []
        for i in range(num_block):
            cur_in_nc = in_nc if i == 0 else out_nc
            conv += [nn.Conv2d(cur_in_nc, out_nc, kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=bias)]
            if use_selu:
                # BUGFIX: the PyTorch module is nn.SELU; the original
                # nn.SeLU(True) raised AttributeError at construction time.
                conv += [nn.SELU(True)]
            else:
                conv += [norm_layer(out_nc), nn.ReLU(True)]
        return conv

    def _segment_forward(self, x):
        """
        predict road surface segmentation
        :param: x, image tensor, [N, C, H, W]
        :return: list of 5 side outputs plus the fused map, all [N, out_nc, H, W]
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.segment_conv1(x)
        conv2 = self.segment_conv2(self.maxpool(conv1))
        conv3 = self.segment_conv3(self.maxpool(conv2))
        conv4 = self.segment_conv4(self.maxpool(conv3))
        conv5 = self.segment_conv5(self.maxpool(conv4))
        # side output features
        side_output1 = self.side_segment_conv1(conv1)
        side_output2 = self.side_segment_conv2(conv2)
        side_output3 = self.side_segment_conv3(conv3)
        side_output4 = self.side_segment_conv4(conv4)
        side_output5 = self.side_segment_conv5(conv5)
        # upsampling side output features back to input resolution
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True) #self.up2(side_output2)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True) #self.up4(side_output3)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True) #self.up8(side_output4)
        side_output5 = F.interpolate(side_output5, size=(h, w), mode='bilinear', align_corners=True) #self.up16(side_output5)

        fused = self.fuse_segment_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4,
            side_output5], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, side_output5, fused]

    def _edge_forward(self, x):
        """
        predict road edge
        :param: x, [image tensor, predicted segmentation tensor], [N, C+1, H, W]
        :return: list of 4 side outputs plus the fused map, all [N, out_nc, H, W]
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.edge_conv1(x)
        conv2 = self.edge_conv2(self.maxpool(conv1))
        conv3 = self.edge_conv3(self.maxpool(conv2))
        conv4 = self.edge_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_edge_conv1(conv1)
        side_output2 = self.side_edge_conv2(conv2)
        side_output3 = self.side_edge_conv3(conv3)
        side_output4 = self.side_edge_conv4(conv4)
        # upsampling side output features
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True) #self.up2(side_output2)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True) #self.up4(side_output3)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True) #self.up8(side_output4)
        fused = self.fuse_edge_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def _centerline_forward(self, x):
        """
        predict road centerline
        :param: x, [image tensor, predicted segmentation tensor], [N, C+1, H, W]
        :return: list of 4 side outputs plus the fused map, all [N, out_nc, H, W]
        """
        h, w = x.size()[2:]
        # main stream features
        conv1 = self.centerline_conv1(x)
        conv2 = self.centerline_conv2(self.maxpool(conv1))
        conv3 = self.centerline_conv3(self.maxpool(conv2))
        conv4 = self.centerline_conv4(self.maxpool(conv3))
        # side output features
        side_output1 = self.side_centerline_conv1(conv1)
        side_output2 = self.side_centerline_conv2(conv2)
        side_output3 = self.side_centerline_conv3(conv3)
        side_output4 = self.side_centerline_conv4(conv4)
        # upsampling side output features
        side_output2 = F.interpolate(side_output2, size=(h, w), mode='bilinear', align_corners=True) #self.up2(side_output2)
        side_output3 = F.interpolate(side_output3, size=(h, w), mode='bilinear', align_corners=True) #self.up4(side_output3)
        side_output4 = F.interpolate(side_output4, size=(h, w), mode='bilinear', align_corners=True) #self.up8(side_output4)
        fused = self.fuse_centerline_conv(torch.cat([
            side_output1,
            side_output2,
            side_output3,
            side_output4], dim=1))
        return [side_output1, side_output2, side_output3, side_output4, fused]

    def forward(self, x):
        """Run all three branches; edge/centerline are conditioned on the
        fused segmentation (concatenated to the input image).

        :return: (segments, edges, centerlines) — the lists produced by the
                 three ``_*_forward`` helpers above.
        """
        segments = self._segment_forward(x)

        x_ = torch.cat([x, segments[-1]], dim=1)
        edges = self._edge_forward(x_)
        centerlines = self._centerline_forward(x_)
        return segments, edges, centerlines
|
| 184 |
+
|
| 185 |
+
def define_roadnet(in_nc,
                   out_nc,
                   ngf,
                   norm='batch',
                   use_selu=1,
                   init_type='xavier',
                   init_gain=0.02,
                   gpu_ids=[]):
    """Factory for RoadNet.

    Builds the three-branch RoadNet, then delegates to ``init_net`` for
    weight initialization (``init_type``/``init_gain``) and device placement
    on ``gpu_ids``. Returns the ready-to-train network.
    """
    model = RoadNet(in_nc, out_nc, ngf, norm=norm, use_selu=use_selu)
    return init_net(model, init_type, init_gain, gpu_ids)
|