| |
| |
| |
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
|
|
| from functools import partial |
|
|
| |
| |
| |
class LocalAffinity(nn.Module):
    """Compute ``centre - neighbour`` differences over a 3x3 window.

    For every dilation ``d`` in ``dilations`` the input is padded by ``d``
    pixels (replicate mode) and convolved, with dilation ``d``, against a
    bank of fixed 3x3 kernels stored in the ``kernel`` buffer. The default
    bank has eight kernels, one per neighbour of the window (row-major,
    centre skipped), each computing ``x[centre] - x[neighbour]``.
    Subclasses override :meth:`_init_aff` to supply different fixed kernels.

    Args:
        dilations: Iterable of integer dilations; one group of kernel
            responses is produced per entry. Defaults to ``[1]``.
    """

    def __init__(self, dilations=None):
        super(LocalAffinity, self).__init__()
        # Avoid a shared mutable default argument, and copy the caller's
        # iterable so later mutation on their side cannot reconfigure us.
        self.dilations = [1] if dilations is None else list(dilations)
        weight = self._init_aff()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        """Build the (8, 1, 3, 3) "centre minus neighbour" kernels.

        Also stores an unregistered copy in ``self.weight_check`` so that
        :meth:`forward` can assert the ``kernel`` buffer was not modified.
        """
        weight = torch.zeros(8, 1, 3, 3)

        # The 8 neighbour positions of a 3x3 window, row-major, centre skipped.
        neighbours = [(r, c) for r in range(3) for c in range(3) if (r, c) != (1, 1)]
        for i, (r, c) in enumerate(neighbours):
            weight[i, 0, 1, 1] = 1
            weight[i, 0, r, c] = -1

        self.weight_check = weight.clone()
        return weight

    def forward(self, x):
        """Apply the fixed kernel bank to every channel of ``x`` independently.

        Args:
            x: 4-D tensor unpacked as ``(B, K, H, W)``.

        Returns:
            Tensor of shape ``(B, K, N * len(self.dilations), H, W)`` where
            ``N = self.kernel.size(0)``; responses for successive dilations
            are concatenated along dimension 2.
        """
        # Sanity check that the fixed kernel is unchanged (e.g. after a
        # checkpoint load). Compare on the buffer's own device/dtype and do
        # NOT overwrite ``weight_check``: forward should not mutate state.
        expected = self.weight_check.to(self.kernel)
        assert torch.equal(expected, self.kernel)

        B, K, H, W = x.size()
        # Treat each channel as a separate single-channel image. ``reshape``
        # (unlike ``view``) also handles non-contiguous inputs.
        x = x.reshape(B * K, 1, H, W)

        x_affs = []
        for d in self.dilations:
            # Padding by d with a dilation-d 3x3 kernel keeps H and W unchanged.
            x_pad = F.pad(x, [d] * 4, mode='replicate')
            x_affs.append(F.conv2d(x_pad, self.kernel, dilation=d))

        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)
|
|
class LocalAffinityCopy(LocalAffinity):
    """Gather the values of the 8 neighbours of every pixel.

    Identical machinery to :class:`LocalAffinity`, but each kernel holds a
    single ``1`` at one neighbour position (row-major, centre skipped) and no
    centre term, so the output contains the neighbour values themselves
    rather than differences.
    """

    def _init_aff(self):
        """Return the (8, 1, 3, 3) one-hot neighbour-selection kernels.

        An unregistered copy is kept in ``self.weight_check`` for the kernel
        integrity assertion made by the inherited ``forward``.
        """
        kernels = torch.zeros(8, 1, 3, 3)
        positions = [
            (row, col)
            for row in range(3)
            for col in range(3)
            if not (row == 1 and col == 1)
        ]
        for idx, (row, col) in enumerate(positions):
            kernels[idx, 0, row, col] = 1

        self.weight_check = kernels.clone()
        return kernels
|
|
class LocalStDev(LocalAffinity):
    """Local standard deviation over each pixel's 3x3 window.

    Nine one-hot kernels select every position of the window (centre
    included), so the inherited ``forward`` gathers the window values; the
    standard deviation is then taken across all gathered values, pooled
    over every dilation.
    """

    def _init_aff(self):
        """Return nine one-hot (9, 1, 3, 3) kernels, one per window position."""
        # Row i of the 9x9 identity has its single 1 at column i, which lands
        # at window position (i // 3, i % 3) after the view: row-major order.
        kernels = torch.eye(9).view(9, 1, 3, 3)
        self.weight_check = kernels.clone()
        return kernels

    def forward(self, x):
        """Return the per-channel local std with shape ``(B, K, 1, H, W)``."""
        window_values = super().forward(x)
        return window_values.std(2, keepdim=True)
|
|
class LocalAffinityAbs(LocalAffinity):
    """:class:`LocalAffinity` variant returning absolute differences."""

    def forward(self, x):
        """Return ``|centre - neighbour|`` for every neighbour and dilation."""
        return super().forward(x).abs()
|
|
| |
| |
| |
class PAMR(nn.Module):
    """Pixel-adaptive mask refinement.

    Each refinement step replaces every mask value by a weighted sum of the
    mask values at the pixel's 8 neighbours (for every dilation). The weights
    are a softmax, over neighbours, of the negative absolute image
    differences normalised by the local image standard deviation, averaged
    over image channels. Neighbours that look more similar to the centre
    pixel therefore contribute more.

    Args:
        num_iter: Number of propagation steps applied to the mask.
        dilations: Iterable of neighbourhood dilations. Defaults to ``[1]``.
    """

    def __init__(self, num_iter=1, dilations=None):
        super(PAMR, self).__init__()
        # Avoid a shared mutable default argument.
        dilations = [1] if dilations is None else list(dilations)

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        """Refine ``mask`` using local affinities computed from ``x``.

        Args:
            x: 4-D guidance tensor ``(B, K, H, W)``.
            mask: 4-D tensor ``(B, C, h, w)``; it is first resized bilinearly
                (``align_corners=True``) to ``x``'s spatial size ``(H, W)``.

        Returns:
            Refined mask of shape ``(B, C, H, W)``.
        """
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        # Local std of x over the 3x3 window(s): (B, K, 1, H, W).
        x_std = self.aff_std(x)

        # Negative |centre - neighbour|, scaled by the local std; the 1e-8
        # keeps the division finite in perfectly flat regions.
        aff = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        # Average over image channels, then normalise over neighbours (dim 2).
        aff = aff.mean(1, keepdim=True)
        aff = F.softmax(aff, 2)

        for _ in range(self.num_iter):
            # Neighbour values of the current mask: (B, C, N, H, W).
            m = self.aff_m(mask)
            mask = (m * aff).sum(2)

        return mask
|
|