import torch
from torch import nn
import torch.nn.functional as F

from models.vgg import VGG_Backbone
from util import *  # provides set_seed() among other helpers


def weights_init(module):
    """Kaiming-initialize conv and linear weights; set norm layers to
    identity (unit weight, zero bias)."""
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
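
# Usage sketch (illustrative, not from the original repo): nn.Module.apply
# walks the module tree and re-initializes every matching layer in place:
#
#     model = Decoder()
#     model.apply(weights_init)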


class EnLayer(nn.Module):
    def __init__(self, in_channel=64):
        super(EnLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        x = self.enlayer(x)
        return x


class LatLayer(nn.Module):
    def __init__(self, in_channel):
        super(LatLayer, self).__init__()
        self.convlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        x = self.convlayer(x)
        return x


class DSLayer(nn.Module):
    def __init__(self, in_channel=64):
        super(DSLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        x = self.enlayer(x)
        x = self.predlayer(x)
        return x


class half_DSLayer(nn.Module):
    def __init__(self, in_channel=512):
        super(half_DSLayer, self).__init__()
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, in_channel // 4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(in_channel // 4, 1, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        x = self.enlayer(x)
        x = self.predlayer(x)
        return x


class AugAttentionModule(nn.Module):
    """Self-attention block that re-weights each attention row by the rank of
    its (positive) affinity scores, then adds the result back to the input."""

    def __init__(self, input_channels=512):
        super(AugAttentionModule, self).__init__()
        self.query_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.key_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.value_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        B, C, H, W = x.size()
        x = self.conv(x)
        # Project to query/key/value and flatten the spatial grid into H*W tokens.
        x_query = self.query_transform(x).view(B, C, -1).permute(0, 2, 1)  # B x HW x C
        x_key = self.key_transform(x).view(B, C, -1)                       # B x C x HW
        x_value = self.value_transform(x).view(B, C, -1).permute(0, 2, 1)  # B x HW x C
        attention_bmm = torch.bmm(x_query, x_key) * self.scale             # B x HW x HW
        attention = F.softmax(attention_bmm, dim=-1)
        # Double argsort: for each score, its rank in descending order (0 = largest).
        attention_sort = torch.sort(attention_bmm, dim=-1, descending=True)[1]
        attention_sort = torch.sort(attention_sort, dim=-1)[1]
        # 0/1 mask of positive affinities. ones_like/zeros_like already inherit
        # the input's device, so the original explicit .cuda() calls are dropped
        # to keep the module device-agnostic.
        attention_positive_num = torch.ones_like(attention)
        attention_positive_num[attention_bmm < 0] = 0
        att_pos_mask = attention_positive_num.clone()
        # Number of positive scores per query row.
        attention_positive_num = torch.sum(attention_positive_num, dim=-1, keepdim=True).expand_as(attention_sort)
        attention_sort_pos = attention_sort.float().clone()
        apn = attention_positive_num - 1
        attention_sort_pos[attention_sort > apn] = 0
        # Positive scores are re-weighted by (rank + 1)^3; non-positive scores
        # keep weight 1.
        attention_mask = ((attention_sort_pos + 1) ** 3) * att_pos_mask + (1 - att_pos_mask)
        out = torch.bmm(attention * attention_mask, x_value)               # B x HW x C
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)                     # back to B x C x H x W
        return out + x
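
# Shape sanity check (a minimal sketch added for illustration, not part of the
# original file): the module is shape-preserving, so for VGG conv5 features of,
# e.g., size (2, 512, 14, 14) the output has the same size:
#
#     aug = AugAttentionModule(input_channels=512)
#     y = aug(torch.randn(2, 512, 14, 14))
#     assert y.shape == (2, 512, 14, 14)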


class AttLayer(nn.Module):
    """Cross-image fusion layer: selects one seed feature per image from the
    group-wise attention map and correlates it with all features to produce a
    consensus-weighted representation."""

    def __init__(self, input_channels=512):
        super(AttLayer, self).__init__()
        self.query_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)

    def correlation(self, x5, seeds):
        B, C, H5, W5 = x5.size()
        # Each seed (B x C x 1 x 1) acts as a 1x1 conv filter over the group.
        if self.training:
            correlation_maps = F.conv2d(x5, weight=seeds)  # B x B x H5 x W5
        else:
            correlation_maps = torch.relu(F.conv2d(x5, weight=seeds))
        # Average over the seed dimension, then min-max normalize per image.
        correlation_maps = correlation_maps.mean(1).view(B, -1)
        min_value = torch.min(correlation_maps, dim=1, keepdim=True)[0]
        max_value = torch.max(correlation_maps, dim=1, keepdim=True)[0]
        correlation_maps = (correlation_maps - min_value) / (max_value - min_value + 1e-12)
        correlation_maps = correlation_maps.view(B, 1, H5, W5)
        return correlation_maps

    def forward(self, x5):
        x5 = self.conv(x5) + x5
        B, C, H5, W5 = x5.size()
        x_query = self.query_transform(x5).view(B, C, -1)
        # BHW x C: every spatial location of every image becomes a query.
        x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C)
        x_key = self.key_transform(x5).view(B, C, -1)
        # C x BHW: keys are shared across the whole image group.
        x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1)
        # For each location: max response within each image, averaged over images.
        x_w1 = torch.matmul(x_query, x_key) * self.scale   # BHW x BHW
        x_w = x_w1.view(B * H5 * W5, B, H5 * W5)
        x_w = torch.max(x_w, -1).values                    # BHW x B
        x_w = x_w.mean(-1)                                 # BHW
        x_w = x_w.view(B, -1)                              # B x HW
        x_w = F.softmax(x_w, dim=-1)

        norm0 = F.normalize(x5, dim=1)

        # Indicator of the strongest location per image: the "seed".
        x_w = x_w.unsqueeze(1)                             # B x 1 x HW
        x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w)
        mask = torch.zeros_like(x_w)  # inherits device; no explicit .cuda() needed
        mask[x_w == x_w_max] = 1
        mask = mask.view(B, 1, H5, W5)
        seeds = norm0 * mask
        seeds = seeds.sum(3).sum(2).unsqueeze(2).unsqueeze(3)  # B x C x 1 x 1
        cormap = self.correlation(norm0, seeds)
        x51 = x5 * cormap
        proto1 = torch.mean(x51, (0, 2, 3), True)          # group prototype, 1 x C x 1 x 1
        return x5, proto1, x5 * proto1 + x51, mask
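
# For reference (a reader's summary of the code above, not original comments):
# given x5 of shape (B, C, H5, W5), forward returns
#     x5               -- residual-refined features,        (B, C, H5, W5)
#     proto1           -- group prototype vector,           (1, C, 1, 1)
#     x5*proto1 + x51  -- prototype/correlation-weighted,   (B, C, H5, W5)
#     mask             -- seed-location indicator,          (B, 1, H5, W5)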


class Decoder(nn.Module):
    """FPN-style top-down decoder with a deeply supervised prediction head at
    every scale."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.toplayer = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0))

        self.latlayer4 = LatLayer(in_channel=512)
        self.latlayer3 = LatLayer(in_channel=256)
        self.latlayer2 = LatLayer(in_channel=128)
        self.latlayer1 = LatLayer(in_channel=64)

        self.enlayer4 = EnLayer()
        self.enlayer3 = EnLayer()
        self.enlayer2 = EnLayer()
        self.enlayer1 = EnLayer()

        self.dslayer4 = DSLayer()
        self.dslayer3 = DSLayer()
        self.dslayer2 = DSLayer()
        self.dslayer1 = DSLayer()

    def _upsample_add(self, x, y):
        [_, _, H, W] = y.size()
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        return x + y

    def forward(self, weighted_x5, x4, x3, x2, x1, H, W):
        preds = []
        p5 = self.toplayer(weighted_x5)

        p4 = self._upsample_add(p5, self.latlayer4(x4))
        p4 = self.enlayer4(p4)
        _pred = self.dslayer4(p4)
        preds.append(F.interpolate(_pred, size=(H, W), mode='bilinear', align_corners=False))

        p3 = self._upsample_add(p4, self.latlayer3(x3))
        p3 = self.enlayer3(p3)
        _pred = self.dslayer3(p3)
        preds.append(F.interpolate(_pred, size=(H, W), mode='bilinear', align_corners=False))

        p2 = self._upsample_add(p3, self.latlayer2(x2))
        p2 = self.enlayer2(p2)
        _pred = self.dslayer2(p2)
        preds.append(F.interpolate(_pred, size=(H, W), mode='bilinear', align_corners=False))

        p1 = self._upsample_add(p2, self.latlayer1(x1))
        p1 = self.enlayer1(p1)
        _pred = self.dslayer1(p1)
        preds.append(F.interpolate(_pred, size=(H, W), mode='bilinear', align_corners=False))

        return preds
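
# Note (reader's summary, not from the original source): `preds` holds four
# deeply supervised side outputs ordered coarse to fine (p4, p3, p2, p1),
# each already upsampled to the input resolution (H, W).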


class DCFMNet(nn.Module):
    """Full network: VGG feature extraction, cross-image fusion (AttLayer),
    attention augmentation, and multi-scale decoding with deep supervision."""

    def __init__(self, mode='train'):
        super(DCFMNet, self).__init__()
        self.gradients = None
        self.backbone = VGG_Backbone()
        self.mode = mode
        self.aug = AugAttentionModule()
        self.fusion = AttLayer(512)
        self.decoder = Decoder()

    def set_mode(self, mode):
        self.mode = mode

    def forward(self, x, gt):
        if self.mode == 'train':
            preds = self._forward(x, gt)
        else:
            # Inference runs without gradient tracking.
            with torch.no_grad():
                preds = self._forward(x, gt)
        return preds

    def featextract(self, x):
        x1 = self.backbone.conv1(x)
        x2 = self.backbone.conv2(x1)
        x3 = self.backbone.conv3(x2)
        x4 = self.backbone.conv4(x3)
        x5 = self.backbone.conv5(x4)
        return x5, x4, x3, x2, x1

    def _forward(self, x, gt):
        [B, _, H, W] = x.size()
        x5, x4, x3, x2, x1 = self.featextract(x)
        feat, proto, weighted_x5, cormap = self.fusion(x5)
        feataug = self.aug(weighted_x5)
        preds = self.decoder(feataug, x4, x3, x2, x1, H, W)
        if self.training:
            # Positive/negative prototypes from ground-truth-masked features,
            # used by the contrastive loss during training.
            gt = F.interpolate(gt, size=weighted_x5.size()[2:], mode='bilinear', align_corners=False)
            feat_pos, proto_pos, weighted_x5_pos, cormap_pos = self.fusion(x5 * gt)
            feat_neg, proto_neg, weighted_x5_neg, cormap_neg = self.fusion(x5 * (1 - gt))
            return preds, proto, proto_pos, proto_neg
        return preds


class DCFM(nn.Module):
    """Thin wrapper around DCFMNet that fixes the random seed at construction."""

    def __init__(self, mode='train'):
        super(DCFM, self).__init__()
        set_seed(123)
        self.dcfmnet = DCFMNet()
        self.mode = mode

    def set_mode(self, mode):
        self.mode = mode
        self.dcfmnet.set_mode(self.mode)

    def forward(self, x, gt):
        preds = self.dcfmnet(x, gt)
        return preds
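

if __name__ == '__main__':
    # Minimal smoke test (an illustrative addition, not part of the original
    # repo). Assumes models.vgg.VGG_Backbone and util.set_seed are importable;
    # constructing the backbone may load pretrained VGG weights.
    model = DCFM(mode='train')
    x = torch.randn(2, 3, 224, 224)   # a group of 2 RGB images
    gt = torch.rand(2, 1, 224, 224)   # soft ground-truth masks in [0, 1]
    preds, proto, proto_pos, proto_neg = model(x, gt)
    print([p.shape for p in preds])   # four (2, 1, 224, 224) side outputs
    print(proto.shape, proto_pos.shape, proto_neg.shape)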