import torch
import torch.nn as nn
import math


# From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
def get_gaussian_kernel(kernel_size=3, pad=2, sigma=2, channels=3):
    """Build a fixed (non-trainable) depthwise Gaussian blur layer.

    Args:
        kernel_size: Side length of the square Gaussian kernel; must be >= 1.
        pad: Offset subtracted from ``kernel_size`` to obtain the convolution
            padding, i.e. the layer is created with
            ``padding = kernel_size - pad``. With the defaults this gives a
            padding of 1, which keeps the spatial size for a 3x3 kernel. For
            another odd ``kernel_size`` the spatial size is preserved when
            ``pad == (kernel_size + 1) // 2``.
        sigma: Standard deviation of the Gaussian, in pixels; must be > 0.
        channels: Number of input/output channels. Each channel is filtered
            independently (``groups=channels``) with the same kernel.

    Returns:
        An ``nn.Conv2d`` without bias whose weight has shape
        ``(channels, 1, kernel_size, kernel_size)``, sums to 1 per channel
        and has ``requires_grad=False``.

    Raises:
        ValueError: If ``kernel_size < 1`` or ``sigma <= 0``. Without this
            check those inputs would silently yield NaN weights.
    """
    if kernel_size < 1:
        raise ValueError(
            "kernel_size must be a positive integer, got {!r}".format(kernel_size)
        )
    if sigma <= 0:
        raise ValueError("sigma must be positive, got {!r}".format(sigma))

    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2).
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.

    # Unnormalised 2-D Gaussian: the product of two 1-D Gaussians in x and y.
    # The usual 1 / (2 * pi * variance) factor is omitted on purpose: it
    # cancels in the normalisation below and can overflow for tiny sigma.
    squared_distance = torch.sum((xy_grid - mean) ** 2., dim=-1)
    gaussian_kernel = torch.exp(-squared_distance / (2 * variance))

    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight: one copy per channel.
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        groups=channels,
        padding=kernel_size - pad,
        bias=False,
    )

    # Install the kernel as a frozen parameter instead of poking at `.data`,
    # so the weights are fixed and excluded from gradient computation.
    gaussian_filter.weight = nn.Parameter(gaussian_kernel, requires_grad=False)

    return gaussian_filter