|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from torch.autograd import Variable |
|
|
|
|
|
def diff_x(input, r):
    """Convert prefix sums along dim 2 into clipped window sums of radius ``r``.

    Meant to be applied to a 4-D tensor that has already been cumulatively
    summed along dim 2 (``BoxFilter.forward`` does this). Entry ``i`` of the
    result is ``cs[min(i + r, H - 1)] - cs[i - r - 1]``, where ``cs`` is the
    input and H is its size along dim 2; the second term is omitted when
    ``i - r - 1 < 0``. That is the sum of the pre-cumsum values at indices
    ``max(0, i - r) .. min(H - 1, i + r)``. When ``H >= 2 * r + 1`` the output
    has the same shape as the input.
    """
    assert input.dim() == 4

    width = 2 * r + 1
    # Leading border: the window is clipped on the low side, so the prefix
    # sum itself is the window sum.
    head = input[..., r:width, :]
    # Interior: the full window is the difference of two prefix sums.
    body = input[..., width:, :] - input[..., :-width, :]
    # Trailing border: the window is clipped on the high side, so the last
    # prefix sum is the upper bound.
    tail = input[..., -1:, :] - input[..., -width:-r - 1, :]
    return torch.cat((head, body, tail), dim=2)
|
|
|
|
|
def diff_y(input, r):
    """Convert prefix sums along dim 3 into clipped window sums of radius ``r``.

    Counterpart of ``diff_x`` for the last axis. Meant to be applied to a 4-D
    tensor that has already been cumulatively summed along dim 3. Entry ``j``
    of the result is the sum of the pre-cumsum values at indices
    ``max(0, j - r) .. min(W - 1, j + r)`` along that axis, where W is the
    size of dim 3. When ``W >= 2 * r + 1`` the output has the same shape as
    the input.
    """
    assert input.dim() == 4

    span = 2 * r + 1
    pieces = [
        # low border: window clipped at index 0
        input[..., r:span],
        # interior: difference of prefix sums spaced one full window apart
        input[..., span:] - input[..., :-span],
        # high border: window clipped at the last index
        input[..., -1:] - input[..., -span:-r - 1],
    ]
    return torch.cat(pieces, dim=3)
|
|
|
|
|
class BoxFilter(nn.Module):
    """Clipped box sum over a (2*r + 1) x (2*r + 1) window on 4-D tensors.

    Each output entry is the sum of the input values in the window centred on
    it along dims 2 and 3. Windows that cross the tensor border are truncated,
    so they cover fewer elements. This assumes both spatial sizes are at least
    ``2 * r + 1``. Dividing by the box sum of an all-ones tensor of the same
    spatial size gives the window mean.
    """

    def __init__(self, r):
        super().__init__()
        # Window radius, shared by both spatial axes.
        self.r = r

    def forward(self, x):
        assert x.dim() == 4
        radius = self.r
        # Separable box sum: prefix-sum along dim 2 and difference into window
        # sums, then repeat the same along dim 3 on that result.
        summed_rows = diff_x(torch.cumsum(x, dim=2), radius)
        return diff_y(torch.cumsum(summed_rows, dim=3), radius)
|
|
|
|
|
|
|
|
class FastGuidedFilter(nn.Module):
    """Guided filter solved at low resolution and applied at high resolution.

    The per-pixel linear coefficients ``A`` and ``b`` are computed from the
    low-resolution guide ``lr_x`` and target ``lr_y`` using box means of
    radius ``r``. They are then bilinearly upsampled to the spatial size of
    ``hr_x`` and applied as ``A * hr_x + b``.

    Args:
        r: Box-filter radius, in low-resolution pixels.
        eps: Regulariser added to the local variance of the guide before
            dividing. Larger values shrink ``A`` toward 0.
    """

    def __init__(self, r, eps=1e-8):
        super().__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        """Filter ``lr_y`` guided by ``lr_x`` and transfer the result to ``hr_x``.

        Args:
            lr_x: Low-resolution guide, shape (n, c, h, w).
            lr_y: Low-resolution target, shape (n, c_y, h, w). ``c`` must be
                1 or equal to ``c_y``.
            hr_x: High-resolution guide, shape (n, c, H, W).

        Returns:
            ``A_hr * hr_x + b_hr``, which broadcasts to shape (n, c_y, H, W).

        Raises:
            AssertionError: if batch sizes, channel counts or low-resolution
                spatial sizes are inconsistent, or if ``h`` or ``w`` is not
                larger than ``2 * r + 1``.
        """
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx, (
            f"batch sizes differ: lr_x={n_lrx}, lr_y={n_lry}, hr_x={n_hrx}"
        )
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry), (
            f"channel mismatch: lr_x={c_lrx}, lr_y={c_lry}, hr_x={c_hrx}"
        )
        assert h_lrx == h_lry and w_lrx == w_lry, (
            f"lr_x {h_lrx}x{w_lrx} and lr_y {h_lry}x{w_lry} must have the same size"
        )
        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1, (
            f"low-res size {h_lrx}x{w_lrx} too small for radius {self.r}"
        )

        # Number of pixels in each (border-clipped) window. The (1, 1, h, w)
        # shape broadcasts over batch and channels. new_ones keeps lr_x's
        # dtype/device and replaces the deprecated
        # Variable(x.data.new().resize_(...).fill_(1.0)) construction.
        N = self.boxfilter(lr_x.new_ones((1, 1, h_lrx, w_lrx)))

        mean_x = self.boxfilter(lr_x) / N
        mean_y = self.boxfilter(lr_y) / N
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        # Local linear model y ~ A * x + b, fitted in every window.
        A = cov_xy / (var_x + self.eps)
        b = mean_y - A * mean_x

        # A and b are upsampled directly; no box-averaging of A/b is applied
        # at low resolution in this implementation.
        A_hr = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        b_hr = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return A_hr * hr_x + b_hr