| |
| |
| |
| |
| from functools import partial |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from imaginaire.layers import Conv2dBlock |
|
|
|
|
class NonLocal2dBlock(nn.Module):
    r"""Self-attention (non-local) layer for 2D feature maps.

    Queries (``theta``), keys (``phi``) and values (``g``) are produced by
    1x1 ``Conv2dBlock`` projections of the input. Keys and values are
    additionally 2x2 max-pooled, which shrinks the attention matrix. The
    attended values are projected back to ``in_channels`` by ``out_conv``
    and added to the input as a scaled residual.

    Args:
        in_channels (int): Number of channels in the input tensor.
        scale (bool, optional, default=True): If ``True``, scale the
            attention output by a learnable parameter initialised to zero,
            so the block starts out as an identity mapping. If ``False``,
            the scale is the constant ``1.0``.
        clamp (bool, optional, default=``False``): If ``True``, clamp the
            learnable scaling parameter to [-1, 1] in the forward pass.
            Has no effect when ``scale`` is ``False`` (the constant 1.0 is
            already in range).
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
    """

    def __init__(self,
                 in_channels,
                 scale=True,
                 clamp=False,
                 weight_norm_type='none'):
        super(NonLocal2dBlock, self).__init__()
        self.clamp = clamp
        # Learnable residual gain starting at 0, or a fixed gain of 1.0.
        self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
        self.in_channels = in_channels
        # Every projection is a 1x1, stride-1, unpadded Conv2dBlock with the
        # same weight normalization.
        base_conv2d_block = partial(Conv2dBlock,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    weight_norm_type=weight_norm_type)
        self.theta = base_conv2d_block(in_channels, in_channels // 8)
        self.phi = base_conv2d_block(in_channels, in_channels // 8)
        self.g = base_conv2d_block(in_channels, in_channels // 2)
        self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
        self.softmax = nn.Softmax(dim=-1)
        # Halves keys/values in each spatial dimension (floor for odd sizes).
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        r"""Apply non-local self-attention with a residual connection.

        Args:
            x (tensor): Input feature maps of shape (B, C, H, W), where C is
                expected to equal ``in_channels``.
        Returns:
            out (tensor): ``gamma * attention_output + x``, with the same
            shape as ``x``. The attention map itself is not returned.
        """
        n, _, h, w = x.size()
        # Queries: (B, H*W, in_channels // 8).
        theta = self.theta(x).flatten(2).permute(0, 2, 1)
        # Keys: (B, in_channels // 8, P) with P = floor(H/2) * floor(W/2).
        # Flatten the pooled map instead of assuming h * w // 4 positions,
        # which is wrong (and makes view() fail) for odd H or W.
        phi = self.max_pool(self.phi(x)).flatten(2)
        # Attention of every input position over the P pooled positions.
        attention = self.softmax(torch.bmm(theta, phi))
        # Values: (B, in_channels // 2, P).
        g = self.max_pool(self.g(x)).flatten(2)
        # (B, C/2, P) x (B, P, H*W) -> (B, C/2, H*W) -> (B, C/2, H, W).
        out = torch.bmm(g, attention.permute(0, 2, 1)).view(n, -1, h, w)
        out = self.out_conv(out)

        gamma = self.gamma
        # With scale=False gamma is the float 1.0, which has no .clamp() and
        # is already inside [-1, 1], so only clamp the learnable tensor.
        if self.clamp and torch.is_tensor(gamma):
            gamma = gamma.clamp(-1, 1)
        return gamma * out + x
|
|