# NOTE(review): removed non-Python page residue ("Spaces:" / "Runtime error"
# lines) that preceded the module source.
| import torch | |
| import torch.nn as nn | |
| from torch.nn import init | |
| import functools | |
| from torch.optim import lr_scheduler | |
| import torch.nn.functional as F | |
| import math | |
| from einops import rearrange | |
| from .transformer_ops.transformer_function import TransformerEncoderLayer | |
| ###################################################################################### | |
| # Attention-Aware Layer | |
| ###################################################################################### | |
class AttnAware(nn.Module):
    """Attention-aware layer that blends two attention results.

    When a previously generated feature map ``pre`` is given, the layer
    computes (a) "context flow": features gathered from ``pre`` via attention
    restricted to the visible region, and (b) self-attention restricted to the
    invisible region, then fuses them with a learned per-pixel 2-way weight.
    Without ``pre`` it falls back to plain multi-head self-attention.
    """

    def __init__(self, input_nc, activation='gelu', norm='pixel', num_heads=2):
        super(AttnAware, self).__init__()
        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        head_dim = input_nc // num_heads
        self.num_heads = num_heads
        self.input_nc = input_nc
        # Scaled-dot-product attention temperature: 1/sqrt(head_dim).
        self.scale = head_dim ** -0.5
        self.query_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )
        self.key_conv = nn.Sequential(
            norm_layer(input_nc),
            activation_layer,
            nn.Conv2d(input_nc, input_nc, kernel_size=1)
        )
        # 1x1 conv mixing the per-head max scores (visible vs invisible) into
        # a 2-channel map; softmaxed over those 2 channels in forward().
        self.weight = nn.Conv2d(self.num_heads*2, 2, kernel_size=1, stride=1)
        # Fuses attended features with the input (channel concat -> input_nc).
        self.to_out = ResnetBlock(input_nc * 2, input_nc, 1, 0, activation, norm)

    def forward(self, x, pre=None, mask=None):
        # NOTE(review): unpacked as (B, C, W, H) although PyTorch layout is
        # (B, C, H, W); W/H are only ever used symmetrically here, so the
        # swap is harmless — but do not rely on the names.
        B, C, W, H = x.size()
        q = self.query_conv(x).view(B, -1, W*H)
        k = self.key_conv(x).view(B, -1, W*H)
        v = x.view(B, -1, W*H)
        # Split channels into heads: (B, heads, N, head_dim), N = W*H.
        q = rearrange(q, 'b (h d) n -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b (h d) n -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b (h d) n -> b h n d', h=self.num_heads)
        # Pairwise attention logits between all spatial positions.
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        if pre is not None:
            # attention-aware weight
            # NOTE(review): N is bound twice from the square logits matrix;
            # both dimensions equal W*H, so only the last binding matters.
            B, head, N, N = dots.size()
            # Mask broadcast over query positions; assumes mask == 1 marks
            # the visible region — TODO confirm against the caller.
            mask_n = mask.view(B, -1, 1, W * H).expand_as(dots)
            # Strongest logit per head toward visible / invisible keys
            # (detached: the weighting net does not backprop through dots).
            w_visible = (dots.detach() * mask_n).max(dim=-1, keepdim=True)[0]
            w_invisible = (dots.detach() * (1-mask_n)).max(dim=-1, keepdim=True)[0]
            weight = torch.cat([w_visible.view(B, head, W, H), w_invisible.view(B, head, W, H)], dim=1)
            weight = self.weight(weight)
            # Softmax over the 2 channels: per-pixel trust in context flow
            # vs self-attention.
            weight = F.softmax(weight, dim=1)
            # visible attention score
            pre_v = pre.view(B, -1, W*H)
            pre_v = rearrange(pre_v, 'b (h d) n -> b h n d', h=self.num_heads)
            # Suppress invisible keys: positive logits are zeroed outside the
            # mask, non-positive ones are scaled hugely negative by the 1e-8
            # division, so the softmax concentrates on visible keys.
            dots_visible = torch.where(dots > 0, dots * mask_n, dots / (mask_n + 1e-8))
            attn_visible = dots_visible.softmax(dim=-1)
            # Gather features of the previous map through visible attention.
            context_flow = torch.einsum('bhij, bhjd->bhid', attn_visible, pre_v)
            context_flow = rearrange(context_flow, 'b h n d -> b (h d) n').view(B, -1, W, H)
            # invisible attention score (mirror of the visible case)
            dots_invisible = torch.where(dots > 0, dots * (1 - mask_n), dots / ((1 - mask_n) + 1e-8))
            attn_invisible = dots_invisible.softmax(dim=-1)
            self_attention = torch.einsum('bhij, bhjd->bhid', attn_invisible, v)
            self_attention = rearrange(self_attention, 'b h n d -> b (h d) n').view(B, -1, W, H)
            # out: blend both results with the learned per-pixel weight.
            out = weight[:, :1, :, :]*context_flow + weight[:, 1:, :, :]*self_attention
        else:
            # Plain multi-head self-attention when no previous feature exists.
            attn = dots.softmax(dim=-1)
            out = torch.einsum('bhij, bhjd->bhid', attn, v)
            out = rearrange(out, 'b h n d -> b (h d) n').view(B, -1, W, H)
        # Residual-style fusion of the attended features with the input.
        out = self.to_out(torch.cat([out, x], dim=1))
        return out
| ###################################################################################### | |
| # base modules | |
| ###################################################################################### | |
class NoiseInjection(nn.Module):
    """Add learnable-scaled Gaussian noise to a feature map.

    The gain starts at zero, so the layer is an identity at initialization.
    When ``mask`` is given, noise is injected only where ``mask`` is 0
    (the invisible / hole region).
    """

    def __init__(self):
        super(NoiseInjection, self).__init__()
        # Learnable scalar gain for the injected noise; zero-initialized.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None, mask=None):
        if noise is None:
            batch, _, height, width = x.size()
            noise = x.new_empty(batch, 1, height, width).normal_()
        if mask is None:
            return x + self.alpha * noise
        # Resize the mask to the feature resolution and restrict the noise
        # to the invisible (mask == 0) region.
        mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=True)
        return x + self.alpha * noise * (1 - mask)
class ConstantInput(nn.Module):
    """Learned constant feature map (position embedding for each VQ word).

    The content of ``input`` is ignored except for its batch size; the learned
    (1, channel, size, size) tensor is tiled along the batch dimension.
    """

    def __init__(self, channel, size=16):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch_size = input.shape[0]
        return self.input.repeat(batch_size, 1, 1, 1)
class UpSample(nn.Module):
    """2x bilinear upsampling, optionally refined by a partial convolution.

    :param input_nc: input channel
    :param with_conv: use convolution to refine the feature
    :param kernel_size: convolution kernel size
    :param return_mask: also return the (upsampled) confidence mask
    """

    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(UpSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size,
                                      stride=1, padding=(kernel_size - 1) // 2,
                                      return_mask=True)

    def forward(self, x, mask=None):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        if mask is not None:
            mask = F.interpolate(mask, scale_factor=2, mode='bilinear', align_corners=True)
        if self.with_conv:
            x, mask = self.conv(x, mask)
        if self.return_mask:
            return x, mask
        return x
class DownSample(nn.Module):
    """2x downsampling via a strided partial conv or average pooling.

    :param input_nc: input channel
    :param with_conv: use convolution to refine the feature
    :param kernel_size: convolution kernel size
    :param return_mask: also return the (downsampled) confidence mask
    """

    def __init__(self, input_nc, with_conv=False, kernel_size=3, return_mask=False):
        super(DownSample, self).__init__()
        self.with_conv = with_conv
        self.return_mask = return_mask
        if with_conv:
            self.conv = PartialConv2d(input_nc, input_nc, kernel_size=kernel_size,
                                      stride=2, padding=(kernel_size - 1) // 2,
                                      return_mask=True)

    def forward(self, x, mask=None):
        if self.with_conv:
            x, mask = self.conv(x, mask)
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
            if mask is not None:
                mask = F.avg_pool2d(mask, kernel_size=2, stride=2)
        if self.return_mask:
            return x, mask
        return x
class ResnetBlock(nn.Module):
    """Pre-activation residual block built from partial convolutions.

    norm -> act -> pconv -> norm -> act -> dropout -> pconv, plus a shortcut
    (1x1-projected when channel counts differ); the sum is divided by sqrt(2)
    to keep the activation variance roughly constant.
    """

    def __init__(self, input_nc, output_nc=None, kernel=3, dropout=0.0,
                 activation='gelu', norm='pixel', return_mask=False):
        super(ResnetBlock, self).__init__()
        norm_layer = get_norm_layer(norm)
        self.return_mask = return_mask
        if output_nc is None:
            output_nc = input_nc
        pad = (kernel - 1) // 2
        self.norm1 = norm_layer(input_nc)
        self.conv1 = PartialConv2d(input_nc, output_nc, kernel_size=kernel, padding=pad, return_mask=True)
        self.norm2 = norm_layer(output_nc)
        self.conv2 = PartialConv2d(output_nc, output_nc, kernel_size=kernel, padding=pad, return_mask=True)
        self.dropout = nn.Dropout(dropout)
        self.act = get_nonlinearity_layer(activation)
        # 1x1 projection shortcut only when the channel count changes.
        if input_nc != output_nc:
            self.short = PartialConv2d(input_nc, output_nc, kernel_size=1, stride=1, padding=0)
        else:
            self.short = Identity()

    def forward(self, x, mask=None):
        residual = self.short(x)
        out, mask = self.conv1(self.act(self.norm1(x)), mask)
        out, mask = self.conv2(self.dropout(self.act(self.norm2(out))), mask)
        out = (out + residual) / math.sqrt(2)
        if self.return_mask:
            return out, mask
        return out
class DiffEncoder(nn.Module):
    """Convolutional encoder with partial convolutions.

    Stem conv -> ``down_scale`` stages of (downsample + ``num_res_blocks``
    residual blocks, doubling channels per stage) -> middle residual (and
    optional transformer attention) blocks -> projection to ``embed_dim``
    channels. The validity mask is propagated through every partial conv.
    """

    def __init__(self, input_nc, ngf=64, kernel_size=2, embed_dim=512, down_scale=4, num_res_blocks=2, dropout=0.0,
                 rample_with_conv=True, activation='gelu', norm='pixel', use_attn=False):
        # NOTE(review): 'rample_with_conv' is presumably a typo for
        # 'resample_with_conv'; kept as-is since callers may pass it by keyword.
        super(DiffEncoder, self).__init__()
        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        # start
        self.encode = PartialConv2d(input_nc, ngf, kernel_size=kernel_size, stride=1, padding=int((kernel_size-1)/2), return_mask=True)
        # down
        self.use_attn = use_attn
        self.down_scale = down_scale
        self.num_res_blocks = num_res_blocks
        self.down = nn.ModuleList()
        out_dim = ngf
        for i in range(down_scale):
            block = nn.ModuleList()
            down = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim * 2)  # channels double at each scale
            down.downsample = DownSample(in_dim, rample_with_conv, kernel_size=2, return_mask=True)
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True))
                in_dim = out_dim  # only the first block changes channel count
            down.block = block
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(out_dim, kernel=1)
        self.mid.block2 = ResnetBlock(out_dim, out_dim, kernel_size, dropout, activation, norm, return_mask=True)
        # end
        self.conv_out = ResnetBlock(out_dim, embed_dim, kernel_size, dropout, activation, norm, return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        """Encode ``x`` (and optionally its validity ``mask``); returns the
        feature map, plus the downsampled mask when ``return_mask`` is True."""
        x, mask = self.encode(x, mask)
        # down sampling
        for i in range(self.down_scale):
            x, mask = self.down[i].downsample(x, mask)
            for i_block in range(self.num_res_blocks):
                x, mask = self.down[i].block[i_block](x, mask)
        # middle
        x, mask = self.mid.block1(x, mask)
        if self.use_attn:
            x = self.mid.attn(x)
        x, mask = self.mid.block2(x, mask)
        # end
        x, mask = self.conv_out(x, mask)
        if return_mask:
            return x, mask
        return x
class DiffDecoder(nn.Module):
    """Convolutional decoder mirroring DiffEncoder.

    Optional learned positional input -> middle residual (and optional
    attention) blocks -> ``up_scale`` stages of (res blocks [+ noise
    injection] + optional attention + 2x upsample), halving channels per
    stage, with a progressively-grown RGB branch (ToRGB skip accumulation).
    """

    def __init__(self, output_nc, ngf=64, kernel_size=3, embed_dim=512, up_scale=4, num_res_blocks=2, dropout=0.0, word_size=16,
                 rample_with_conv=True, activation='gelu', norm='pixel', add_noise=False, use_attn=True, use_pos=True):
        super(DiffDecoder, self).__init__()
        activation_layer = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.up_scale = up_scale
        self.num_res_blocks = num_res_blocks
        self.add_noise = add_noise
        self.use_attn = use_attn
        self.use_pos = use_pos
        in_dim = ngf * (2 ** self.up_scale)
        # start
        if use_pos:
            # Learned constant positional embedding, one per VQ word position.
            self.pos_embed = ConstantInput(embed_dim, size=word_size)
        self.conv_in = PartialConv2d(embed_dim, in_dim, kernel_size=kernel_size, stride=1, padding=int((kernel_size-1)/2))
        # middle
        self.mid = nn.Module()
        self.mid.block1 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)
        if self.use_attn:
            self.mid.attn = TransformerEncoderLayer(in_dim, kernel=1)
        self.mid.block2 = ResnetBlock(in_dim, in_dim, kernel_size, dropout, activation, norm)
        # up
        self.up = nn.ModuleList()
        out_dim = in_dim
        for i in range(up_scale):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            noise = nn.ModuleList()
            up = nn.Module()
            in_dim = out_dim
            out_dim = int(in_dim / 2)  # channels halve at each scale
            for i_block in range(num_res_blocks):
                if add_noise:
                    noise.append(NoiseInjection())
                block.append(ResnetBlock(in_dim, out_dim, kernel_size, dropout, activation, norm))
                in_dim = out_dim
                # NOTE(review): indentation reconstructed — forward() indexes
                # attn[i_block], so one attention layer must be appended per
                # res block at the first (lowest-resolution) stage only.
                if i == 0 and self.use_attn:
                    attn.append(TransformerEncoderLayer(in_dim, kernel=1))
            up.block = block
            up.attn = attn
            up.noise = noise
            upsample = True if (i != 0) else False
            up.out = ToRGB(in_dim, output_nc, upsample, activation, norm)
            up.upsample = UpSample(in_dim, rample_with_conv, kernel_size=3)
            self.up.append(up)
        # end
        self.decode = ToRGB(in_dim, output_nc, True, activation, norm)

    def forward(self, x, mask=None):
        # Add the learned positional embedding (same spatial size as x).
        x = x + self.pos_embed(x) if self.use_pos else x
        x = self.conv_in(x)
        # middle
        x = self.mid.block1(x)
        if self.use_attn:
            x = self.mid.attn(x)
        x = self.mid.block2(x)
        # up
        skip = None  # progressively-grown RGB image accumulated by ToRGB
        for i in range(self.up_scale):
            for i_block in range(self.num_res_blocks):
                if self.add_noise:
                    # Inject noise only into the masked (invisible) region.
                    x = self.up[i].noise[i_block](x, mask=mask)
                x = self.up[i].block[i_block](x)
                if len(self.up[i].attn) > 0:
                    x = self.up[i].attn[i_block](x)
            skip = self.up[i].out(x, skip)
            x = self.up[i].upsample(x)
        # end
        x = self.decode(x, skip)
        return x
class LinearEncoder(nn.Module):
    """Patch-embedding encoder: a single strided partial convolution turns
    each non-overlapping kernel_size x kernel_size patch into one
    ``embed_dim``-dimensional token."""

    def __init__(self, input_nc, kernel_size=16, embed_dim=512):
        super(LinearEncoder, self).__init__()
        self.encode = PartialConv2d(input_nc, embed_dim, kernel_size=kernel_size,
                                    stride=kernel_size, return_mask=True)

    def forward(self, x, mask=None, return_mask=False):
        x, mask = self.encode(x, mask)
        if return_mask:
            return x, mask
        return x
class LinearDecoder(nn.Module):
    """Patch decoder: expands each token into a kernel_size x kernel_size
    patch via PixelShuffle, then maps to ``output_nc`` channels with a
    tanh-bounded output."""

    def __init__(self, output_nc, ngf=64, kernel_size=16, embed_dim=512,
                 activation='gelu', norm='pixel'):
        super(LinearDecoder, self).__init__()
        act = get_nonlinearity_layer(activation)
        norm_layer = get_norm_layer(norm)
        self.decode = nn.Sequential(
            norm_layer(embed_dim),
            act,
            PartialConv2d(embed_dim, ngf * kernel_size * kernel_size, kernel_size=3, padding=1),
            nn.PixelShuffle(kernel_size),
            norm_layer(ngf),
            act,
            PartialConv2d(ngf, output_nc, kernel_size=3, padding=1),
        )

    def forward(self, x, mask=None):
        return torch.tanh(self.decode(x))
class ToRGB(nn.Module):
    """Projects features to the output (RGB) space with tanh, optionally
    fusing a lower-resolution skip image that is first upsampled 2x.

    NOTE(review): ``self.upsample`` only exists when ``upsample=True``, so
    callers must not pass ``skip`` when it was constructed with False.
    """

    def __init__(self, input_nc, output_nc, upsample=True, activation='gelu', norm='pixel'):
        super().__init__()
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # The skip image is concatenated on the channel axis.
            input_nc = input_nc + output_nc
        self.conv = nn.Sequential(
            get_norm_layer(norm)(input_nc),
            get_nonlinearity_layer(activation),
            PartialConv2d(input_nc, output_nc, kernel_size=3, padding=1),
        )

    def forward(self, input, skip=None):
        if skip is not None:
            skip = self.upsample(skip)
            input = torch.cat([input, skip], dim=1)
        return torch.tanh(self.conv(input))
| ###################################################################################### | |
| # base function for network structure | |
| ###################################################################################### | |
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a
                              subclass of BaseOptions. opt.lr_policy is the
                              name of the policy: linear | plateau | cosine

    Raises:
        NotImplementedError -- if opt.lr_policy is not a supported policy.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(iter):
            # Linear decay to 0 over n_iter_decay iterations once the running
            # iteration count (iter + iter_count) passes n_iter.
            lr_l = 1.0 - max(0, iter + opt.iter_count - opt.n_iter) / float(opt.n_iter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # BUG FIX: the error was previously *returned* instead of raised, and
        # used logging-style comma args that never formatted the message.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        debug (bool)      -- print the class name of each initialized layer.

    'normal' was used in the original pix2pix/CycleGAN paper; xavier and
    kaiming may work better for some applications.
    """
    def init_func(m):
        name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if debug:
                print(name)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm weight is not a matrix; only normal init applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
def init_net(net, init_type='normal', init_gain=0.02, debug=False, initialize_weights=True):
    """Initialize a network's weights and return it.

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier, orthogonal.
        initialize_weights (bool) -- skip weight initialization entirely when False.

    Return the (possibly initialized) network.
    """
    if not initialize_weights:
        return net
    init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
class Identity(nn.Module):
    """No-op module: returns its input unchanged."""

    def forward(self, x):
        return x
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer factory (callable taking a channel count).

    Parameters:
        norm_type (str) -- batch | instance | pixel | layer | none

    For BatchNorm we use learnable affine parameters and track running
    statistics; for InstanceNorm we use affine parameters without running
    statistics.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=True)
    if norm_type == 'pixel':
        # functools.partial with no bound args is a no-op wrapper; the bare
        # class behaves identically when called with the channel count.
        return PixelwiseNorm
    if norm_type == 'layer':
        return nn.LayerNorm
    if norm_type == 'none':
        return lambda x: Identity()
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def get_nonlinearity_layer(activation_type='PReLU'):
    """Get the activation layer for the networks.

    Parameters:
        activation_type (str) -- relu | gelu | leakyrelu | prelu
                                 (matched case-insensitively)

    Raises:
        NotImplementedError -- for unsupported activation names.

    BUG FIX: matching is now case-insensitive; previously the declared
    default value 'PReLU' itself raised NotImplementedError because only
    the lowercase spelling 'prelu' matched. All previously accepted
    spellings still work.
    """
    factories = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'leakyrelu': functools.partial(nn.LeakyReLU, 0.2),
        'prelu': nn.PReLU,
    }
    key = activation_type.lower()
    if key not in factories:
        raise NotImplementedError('activation layer [%s] is not found' % activation_type)
    return factories[key]()
class PixelwiseNorm(nn.Module):
    """Pixelwise feature normalization with a learnable per-channel gain.

    Each spatial position is divided by the RMS of its channel vector
    (ProGAN-style pixel norm), then scaled by ``self.alpha``.
    """

    def __init__(self, input_nc):
        super(PixelwiseNorm, self).__init__()
        self.init = False
        # Learnable per-channel gain, initialized to 1 (identity scaling).
        self.alpha = nn.Parameter(torch.ones(1, input_nc, 1, 1))

    def forward(self, x, alpha=1e-8):
        """Normalize each spatial position by its channel-wise RMS.

        :param x: input activations volume [N, C, H, W]
        :param alpha: small number for numerical stability
        :return: pixel-normalized activations scaled by the learnable gain
        """
        inv_rms = torch.rsqrt(x.pow(2.0).mean(dim=1, keepdim=True) + alpha)  # [N, 1, H, W]
        return self.alpha * (x * inv_rms)
| ############################################################################### | |
| # BSD 3-Clause License | |
| # | |
| # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Author & Contact: Guilin Liu (guilinl@nvidia.com) | |
| ############################################################################### | |
class PartialConv2d(nn.Conv2d):
    """Partial convolution (Liu et al., "Image Inpainting for Irregular Holes
    Using Partial Convolutions", ECCV 2018; NVIDIA reference implementation).

    The convolution output at each location is renormalized by the fraction
    of valid (mask == 1) input pixels under the kernel window, and an updated
    validity mask can be returned alongside the features.
    """

    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False
        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False
        super(PartialConv2d, self).__init__(*args, **kwargs)
        # All-ones kernel used with conv2d to count valid pixels per window.
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        # Total number of elements in one sliding window (renormalization target).
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]
        # Cache: mask statistics are only recomputed when a new mask arrives
        # or the input size changes (see forward()).
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        :param input: (B, C, H, W) feature tensor
        :param mask_in: validity mask (1 = valid). When None, the cached mask
                        statistics from the previous call are reused, or an
                        all-ones mask is created on a size change.
        """
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                # Count of valid input pixels under each output window.
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)
                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                # Binary "any valid pixel in this window" indicator.
                self.update_mask1 = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio where no valid pixel exists.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask1)
        # Convolve masked input (only re-masked when a fresh mask was given).
        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
        if self.bias is not None:
            # Renormalize the bias-free output, re-add the bias, then zero
            # out fully-invalid windows.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask1)
        else:
            output = torch.mul(raw_out, self.mask_ratio)
        if self.return_mask:
            return output, self.update_mask / self.slide_winsize  # replace the valid value to confident score
        else:
            return output