|
|
import torch |
|
|
import torch.nn as nn |
|
|
import math |
|
|
|
|
|
|
|
|
|
|
|
def get_gaussian_kernel(kernel_size: int = 3, pad: int = 2, sigma: float = 2,
                        channels: int = 3) -> nn.Conv2d:
    """Build a frozen depthwise ``nn.Conv2d`` that applies a 2-D Gaussian blur.

    Every channel is filtered independently (``groups=channels``) with the same
    ``kernel_size x kernel_size`` Gaussian, normalised so its weights sum to 1.

    Args:
        kernel_size: Side length of the square kernel; must be at least 1.
        pad: Amount subtracted from ``kernel_size`` to obtain the convolution's
            zero padding, i.e. ``padding = kernel_size - pad``. With the
            layer's default stride of 1 the output keeps the input's spatial
            size only when ``pad == (kernel_size + 1) / 2`` (true for the
            defaults: 3 - 2 = 1).
        sigma: Standard deviation of the Gaussian, in pixels.
        channels: Number of input (and output) channels; must be at least 1.

    Returns:
        An ``nn.Conv2d`` without bias whose weight has shape
        ``(channels, 1, kernel_size, kernel_size)`` and ``requires_grad=False``.

    Raises:
        ValueError: If ``kernel_size`` or ``channels`` is smaller than 1, or if
            ``kernel_size - pad`` is negative.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if channels < 1:
        raise ValueError(f"channels must be >= 1, got {channels}")
    padding = kernel_size - pad
    if padding < 0:
        # Fail at construction time instead of at the first forward pass.
        raise ValueError(
            f"kernel_size - pad must be >= 0, got {kernel_size} - {pad} = {padding}"
        )

    mean = (kernel_size - 1) / 2.
    variance = sigma ** 2.

    # Squared distance of every kernel cell from the kernel centre, obtained by
    # broadcasting the 1-D squared offsets along rows and columns.
    sq_offsets = (torch.arange(kernel_size).float() - mean) ** 2.
    sq_dist = sq_offsets.unsqueeze(0) + sq_offsets.unsqueeze(1)

    # The 1 / (2*pi*variance) prefactor cancels in the normalisation below; it
    # is retained so the computed values stay the same as before.
    gaussian_kernel = (1. / (2. * math.pi * variance)) * torch.exp(
        -sq_dist / (2 * variance)
    )

    # Normalise so that the kernel weights sum to 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Shape into a depthwise convolution weight: one identical filter per channel.
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)

    gaussian_filter = nn.Conv2d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=kernel_size,
        groups=channels,
        padding=padding,
        bias=False,
    )

    # Install the fixed kernel as a non-trainable parameter.
    gaussian_filter.weight = nn.Parameter(gaussian_kernel, requires_grad=False)

    return gaussian_filter
|
|
|