# pylint: disable=missing-module-docstring,invalid-name
# pylint: disable=missing-docstring
# pylint: disable=line-too-long

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
| r"""Applies Layer Normalization over a mini-batch of inputs as described in | |
| the paper `Layer Normalization`_ . | |
| .. math:: | |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta | |
| The mean and standard-deviation are calculated separately over the last | |
| certain number dimensions which have to be of the shape specified by | |
| :attr:`normalized_shape`. | |
| :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of | |
| :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. | |
| .. note:: | |
| Unlike Batch Normalization and Instance Normalization, which applies | |
| scalar scale and bias for each entire channel/plane with the | |
| :attr:`affine` option, Layer Normalization applies per-element scale and | |
| bias with :attr:`elementwise_affine`. | |
| This layer uses statistics computed from input data in both training and | |
| evaluation modes. | |
| Args: | |
| normalized_shape (int or list or torch.Size): input shape from an expected input | |
| of size | |
| .. math:: | |
| [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] | |
| \times \ldots \times \text{normalized\_shape}[-1]] | |
| If a single integer is used, it is treated as a singleton list, and this module will | |
| normalize over the last dimension which is expected to be of that specific size. | |
| eps: a value added to the denominator for numerical stability. Default: 1e-5 | |
| elementwise_affine: a boolean value that when set to ``True``, this module | |
| has learnable per-element affine parameters initialized to ones (for weights) | |
| and zeros (for biases). Default: ``True``. | |
| Shape: | |
| - Input: :math:`(N, *)` | |
| - Output: :math:`(N, *)` (same shape as input) | |
| Examples:: | |
| >>> input = torch.randn(20, 5, 10, 10) | |
| >>> # With Learnable Parameters | |
| >>> m = nn.LayerNorm(input.size()[1:]) | |
| >>> # Without Learnable Parameters | |
| >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False) | |
| >>> # Normalize over last two dimensions | |
| >>> m = nn.LayerNorm([10, 10]) | |
| >>> # Normalize over last dimension of size 10 | |
| >>> m = nn.LayerNorm(10) | |
| >>> # Activating the module | |
| >>> output = m(input) | |
| .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 | |
| """ | |
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super(LayerNorm, self).__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        if self.scale:
            self.weight = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('weight', None)

        if self.center:
            self.bias = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.scale:
            nn.init.ones_(self.weight)
        if self.center:
            nn.init.zeros_(self.bias)

    def adjust_parameter(self, tensor, parameter):
        # broadcast a per-channel parameter of shape (C,) to the full
        # normalized shape (C, H, W) of the given input tensor
        return torch.repeat_interleave(
            torch.repeat_interleave(
                parameter.view(-1, 1, 1),
                repeats=tensor.shape[2],
                dim=1),
            repeats=tensor.shape[3],
            dim=2
        )

    def forward(self, input):
        normalized_shape = (self.features, input.shape[2], input.shape[3])
        # F.layer_norm accepts None for a missing weight or bias, so only
        # broadcast the parameters that actually exist
        weight = self.adjust_parameter(input, self.weight) if self.scale else None
        bias = self.adjust_parameter(input, self.bias) if self.center else None
        return F.layer_norm(input, normalized_shape, weight, bias, self.eps)

    def extra_repr(self):
        return '{features}, eps={eps}, ' \
            'center={center}, scale={scale}'.format(**self.__dict__)


def gaussian_filter_1d(tensor, dim, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0):
    sigma = torch.as_tensor(sigma, device=tensor.device, dtype=tensor.dtype)

    if kernel_size is not None:
        kernel_size = torch.as_tensor(kernel_size, device=tensor.device, dtype=torch.int64)
    else:
        # truncate the kernel at `truncate` standard deviations on each side
        kernel_size = torch.as_tensor(2 * torch.ceil(truncate * sigma) + 1, device=tensor.device, dtype=torch.int64)

    kernel_size = kernel_size.detach()
    kernel_size_int = kernel_size.detach().cpu().numpy()

    # grid of kernel positions, centered around zero
    mean = (torch.as_tensor(kernel_size, dtype=tensor.dtype) - 1) / 2
    grid = torch.arange(kernel_size, device=tensor.device) - mean

    # reshape the grid so that it can be used as a kernel for F.conv1d
    kernel_shape = (1, 1, kernel_size)
    grid = grid.view(kernel_shape)
    grid = grid.detach()

    # move the filtered dimension to the end so that a 1d convolution
    # over the last dimension applies the filter
    source_shape = tensor.shape
    tensor = torch.movedim(tensor, dim, len(source_shape) - 1)
    dim_last_shape = tensor.shape
    assert tensor.shape[-1] == source_shape[dim]

    # we need reshape instead of view for batches like B x C x H x W
    tensor = tensor.reshape(-1, 1, source_shape[dim])

    padding = (math.ceil((kernel_size_int - 1) / 2), math.ceil((kernel_size_int - 1) / 2))
    tensor_ = F.pad(tensor, padding, padding_mode, padding_value)

    # create gaussian kernel from grid using current sigma
    kernel = torch.exp(-0.5 * (grid / sigma) ** 2)
    kernel = kernel / kernel.sum()

    # convolve input with gaussian kernel
    tensor_ = F.conv1d(tensor_, kernel)
    tensor_ = tensor_.view(dim_last_shape)

    # move the filtered dimension back to its original position
    tensor_ = torch.movedim(tensor_, len(source_shape) - 1, dim)
    assert tensor_.shape == source_shape

    return tensor_
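

# Example (illustrative sketch): blurring a batch of images along the height
# dimension only (dim=2 of a B x C x H x W tensor).
#     >>> images = torch.randn(4, 3, 32, 32)
#     >>> blurred = gaussian_filter_1d(images, dim=2, sigma=1.5)
#     >>> blurred.shape
#     torch.Size([4, 3, 32, 32])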


class GaussianFilterNd(nn.Module):
    """A differentiable gaussian filter"""

    def __init__(self, dims, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0,
                 trainable=False):
| """Creates a 1d gaussian filter | |
| Args: | |
| dims ([int]): the dimensions to which the gaussian filter is applied. Negative values won't work | |
| sigma (float): standard deviation of the gaussian filter (blur size) | |
| input_dims (int, optional): number of input dimensions ignoring batch and channel dimension, | |
| i.e. use input_dims=2 for images (default: 2). | |
| truncate (float, optional): truncate the filter at this many standard deviations (default: 4.0). | |
| This has no effect if the `kernel_size` is explicitely set | |
| kernel_size (int): size of the gaussian kernel convolved with the input | |
| padding_mode (string, optional): Padding mode implemented by `torch.nn.functional.pad`. | |
| padding_value (string, optional): Value used for constant padding. | |
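
        Example (illustrative; assumes 4D image batches of shape B x C x H x W)::
            >>> blur = GaussianFilterNd(dims=[2, 3], sigma=2.0)
            >>> images = torch.randn(8, 3, 64, 64)
            >>> blurred = blur(images)  # same shape: (8, 3, 64, 64)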
| """ | |
| # IDEA determine input_dims dynamically for every input | |
| super(GaussianFilterNd, self).__init__() | |
| self.dims = dims | |
| self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float32), requires_grad=trainable) # default: no optimization | |
| self.truncate = truncate | |
| self.kernel_size = kernel_size | |
| # setup padding | |
| self.padding_mode = padding_mode | |
| self.padding_value = padding_value | |

    def forward(self, tensor):
        """Applies the gaussian filter to the given tensor"""
        for dim in self.dims:
            tensor = gaussian_filter_1d(
                tensor,
                dim=dim,
                sigma=self.sigma,
                truncate=self.truncate,
                kernel_size=self.kernel_size,
                padding_mode=self.padding_mode,
                padding_value=self.padding_value,
            )

        return tensor


class Conv2dMultiInput(nn.Module):
    """Applies a separate 2d convolution to each tensor in a list of inputs
    and sums the results. Inputs whose channel count is zero are skipped.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        for k, _in_channels in enumerate(in_channels):
            if _in_channels:
                setattr(self, f'conv_part{k}', nn.Conv2d(_in_channels, out_channels, kernel_size, bias=bias))

    def forward(self, tensors):
        assert len(tensors) == len(self.in_channels)

        out = None
        for k, (count, tensor) in enumerate(zip(self.in_channels, tensors)):
            if not count:
                continue
            _out = getattr(self, f'conv_part{k}')(tensor)
            if out is None:
                out = _out
            else:
                out += _out

        return out

    # def extra_repr(self):
    #     return f'{self.in_channels}'
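

# Example (illustrative sketch): combining two feature streams; a stream with
# zero channels would simply be skipped by Conv2dMultiInput.
#     >>> conv = Conv2dMultiInput(in_channels=[16, 8], out_channels=32, kernel_size=1)
#     >>> out = conv([torch.randn(4, 16, 32, 32), torch.randn(4, 8, 32, 32)])
#     >>> out.shape
#     torch.Size([4, 32, 32, 32])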


class LayerNormMultiInput(nn.Module):
    """Applies a separate LayerNorm to each tensor in a list of inputs.

    Inputs whose feature count is zero are expected to be ``None`` and are
    passed through unchanged.
    """
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super().__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        for k, _features in enumerate(features):
            if _features:
                setattr(self, f'layernorm_part{k}', LayerNorm(_features, eps=eps, center=center, scale=scale))

    def forward(self, tensors):
        assert len(tensors) == len(self.features)

        out = []
        for k, (count, tensor) in enumerate(zip(self.features, tensors)):
            if not count:
                assert tensor is None
                out.append(None)
                continue
            out.append(getattr(self, f'layernorm_part{k}')(tensor))

        return out
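

# Example (illustrative sketch): normalizing two feature streams where the
# second one is absent (features=0 expects and returns None).
#     >>> norm = LayerNormMultiInput([5, 0])
#     >>> out = norm([torch.randn(2, 5, 8, 8), None])
#     >>> out[0].shape, out[1]
#     (torch.Size([2, 5, 8, 8]), None)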


class Bias(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, tensor):
        # add the per-channel bias, broadcast over batch and spatial dimensions
        return tensor + self.bias[np.newaxis, :, np.newaxis, np.newaxis]

    def extra_repr(self):
        return f'channels={self.channels}'
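

# Example (illustrative): adding a learnable per-channel bias to a 4D tensor.
#     >>> bias = Bias(3)
#     >>> bias(torch.randn(2, 3, 4, 4)).shape
#     torch.Size([2, 3, 4, 4])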


class SelfAttention(nn.Module):
    """Self attention layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """
    def __init__(self, in_channels, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False, return_attention=True):
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        if key_channels is None:
            key_channels = in_channels // 8
        self.key_channels = key_channels
        self.activation = activation
        self.skip_connection_with_convolution = skip_connection_with_convolution
        if not self.skip_connection_with_convolution:
            if self.out_channels != self.in_channels:
                raise ValueError("out_channels must equal in_channels when using an identity skip connection!")
        self.return_attention = return_attention

        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        if self.skip_connection_with_convolution:
            self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x: input feature maps (B x C x H x W)
        returns:
            out: self attention value + input feature
            attention: B x N x N (N = H * W)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x key_channels
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x key_channels x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x out_channels x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_channels, width, height)

        if self.skip_connection_with_convolution:
            skip_connection = self.skip_conv(x)
        else:
            skip_connection = x
        out = self.gamma * out + skip_connection

        if self.activation is not None:
            out = self.activation(out)

        if self.return_attention:
            return out, attention
        return out
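

# Example (illustrative sketch): self attention over an 8-channel feature map;
# by default the attention map over the N = H * W positions is returned too.
#     >>> attn = SelfAttention(in_channels=8)
#     >>> out, attention = attn(torch.randn(2, 8, 16, 16))
#     >>> out.shape, attention.shape
#     (torch.Size([2, 8, 16, 16]), torch.Size([2, 256, 256]))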


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self attention layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """
    def __init__(self, in_channels, heads, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(
            in_channels=in_channels,
            out_channels=out_channels,
            key_channels=key_channels,
            activation=activation,
            skip_connection_with_convolution=skip_connection_with_convolution,
            return_attention=False,
        ) for _ in range(heads)])

    def forward(self, tensor):
        # concatenate the outputs of all heads along the channel dimension
        outs = [head(tensor) for head in self.heads]
        out = torch.cat(outs, dim=1)
        return out
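

# Example (illustrative sketch): two attention heads concatenated along the
# channel dimension give heads * in_channels output channels.
#     >>> mhsa = MultiHeadSelfAttention(in_channels=8, heads=2)
#     >>> mhsa(torch.randn(2, 8, 16, 16)).shape
#     torch.Size([2, 16, 16, 16])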


class FlexibleScanpathHistoryEncoding(nn.Module):
    """A convolutional layer which works for different numbers of previous fixations.

    Nonexistent fixations deactivate the respective convolutions; the bias is
    added per fixation (if the given fixation is present).
    """
    def __init__(self, in_fixations, channels_per_fixation, out_channels, kernel_size, bias=True):
        super().__init__()
        self.in_fixations = in_fixations
        self.channels_per_fixation = channels_per_fixation
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.bias = bias
        self.convolutions = nn.ModuleList([
            nn.Conv2d(
                in_channels=self.channels_per_fixation,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                bias=self.bias
            ) for _ in range(in_fixations)
        ])

    def forward(self, tensor):
        results = None
        # a fixation is considered present if its first channel is not NaN
        valid_fixations = ~torch.isnan(
            tensor[:, :self.in_fixations, 0, 0]
        )

        for fixation_index in range(self.in_fixations):
            valid_indices = valid_fixations[:, fixation_index]
            if not torch.any(valid_indices):
                continue
            # channels are interleaved per fixation, so this slice selects the
            # channels belonging to the current fixation
            this_input = tensor[
                valid_indices,
                fixation_index::self.in_fixations
            ]
            this_result = self.convolutions[fixation_index](
                this_input
            )
            # TODO: This will return None if no data point
            # in the batch has a single fixation,
            # but that's not a case I intend to train
            # anyway.
            if results is None:
                b, _, _, _ = tensor.shape
                _, _, h, w = this_result.shape
                results = torch.zeros(
                    (b, self.out_channels, h, w),
                    dtype=tensor.dtype,
                    device=tensor.device
                )
            results[valid_indices] += this_result

        return results
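

# Example (illustrative sketch): encoding up to 3 previous fixations with 2
# channels per fixation; a missing fixation is marked by NaNs in its
# (interleaved) channels.
#     >>> enc = FlexibleScanpathHistoryEncoding(
#     ...     in_fixations=3, channels_per_fixation=2, out_channels=4, kernel_size=1)
#     >>> x = torch.randn(5, 6, 8, 8)
#     >>> x[0, 1::3] = float('nan')  # second fixation missing for the first sample
#     >>> enc(x).shape
#     torch.Size([5, 4, 8, 8])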