| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from models.vgg import VGG_Backbone |
| from util import * |
|
|
|
|
def weights_init(module):
    """Initialize a single module in place (intended for `model.apply`).

    Conv2d/Linear weights get Kaiming-normal (fan-in, relu) init;
    BatchNorm2d/GroupNorm weights are set to one. Any present bias
    is zeroed. Other module types are left untouched.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
    elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
    else:
        return
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
|
|
|
class EnLayer(nn.Module):
    """Two 3x3 convs (ReLU between) mapping `in_channel` -> 64 channels.

    Spatial size is preserved (padding=1, stride=1).
    """

    def __init__(self, in_channel=64):
        super(EnLayer, self).__init__()
        # Attribute name `enlayer` kept for checkpoint compatibility.
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
        )

    def forward(self, x):
        return self.enlayer(x)
|
|
|
|
class LatLayer(nn.Module):
    """Lateral connection: projects a backbone feature map to 64 channels
    via two 3x3 convs with a ReLU in between; spatial size unchanged.
    """

    def __init__(self, in_channel):
        super(LatLayer, self).__init__()
        # Attribute name `convlayer` kept for checkpoint compatibility.
        self.convlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
        )

    def forward(self, x):
        return self.convlayer(x)
|
|
|
|
class DSLayer(nn.Module):
    """Deep-supervision head: refines features with two 3x3 conv+ReLU
    stages, then predicts a single-channel map via a 1x1 conv (raw
    logits, no activation).
    """

    def __init__(self, in_channel=64):
        super(DSLayer, self).__init__()
        # Attribute names kept for checkpoint compatibility.
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1),
        )

    def forward(self, x):
        return self.predlayer(self.enlayer(x))
|
|
|
|
class half_DSLayer(nn.Module):
    """Lightweight prediction head: a 3x3 conv squeezing channels to a
    quarter (with ReLU), followed by a 1x1 conv to one logit channel.
    """

    def __init__(self, in_channel=512):
        super(half_DSLayer, self).__init__()
        mid_channel = in_channel // 4
        # Attribute names kept for checkpoint compatibility.
        self.enlayer = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.predlayer = nn.Sequential(
            nn.Conv2d(mid_channel, 1, kernel_size=1),
        )

    def forward(self, x):
        return self.predlayer(self.enlayer(x))
|
|
|
|
class AugAttentionModule(nn.Module):
    """Self-attention block that augments features with a rank-reweighted
    attention map.

    Attention logits come from 1x1-projected query/key features. Entries
    with a positive logit are re-weighted by the cube of their 1-based
    rank among their row's positive entries (emphasizing the strongest
    correlations); negative-logit entries keep weight 1. The attended
    value is added back to the (post-conv) input as a residual.
    """

    def __init__(self, input_channels=512):
        super(AugAttentionModule, self).__init__()
        # Two stacked 1x1 convs for each of the query/key/value projections.
        self.query_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.key_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        self.value_transform = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
        )
        # Standard scaled dot-product attention factor, 1/sqrt(d).
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply rank-reweighted self-attention.

        Args:
            x: feature map of shape (B, C, H, W), C == input_channels.

        Returns:
            Tensor of shape (B, C, H, W): attended features plus a
            residual of the pre-attention (post-`conv`) features.
        """
        B, C, H, W = x.size()
        x = self.conv(x)
        x_query = self.query_transform(x).view(B, C, -1).permute(0, 2, 1)  # (B, HW, C)
        x_key = self.key_transform(x).view(B, C, -1)                       # (B, C, HW)
        x_value = self.value_transform(x).view(B, C, -1).permute(0, 2, 1)  # (B, HW, C)
        attention_bmm = torch.bmm(x_query, x_key) * self.scale             # (B, HW, HW)
        attention = F.softmax(attention_bmm, dim=-1)
        # Double argsort: rank of each logit within its row (0 = largest).
        attention_sort = torch.sort(attention_bmm, dim=-1, descending=True)[1]
        attention_sort = torch.sort(attention_sort, dim=-1)[1]
        # BUGFIX: was `torch.ones_like(attention).cuda()`, which forced the
        # mask onto the default GPU and crashed CPU / non-default-device runs.
        # ones_like already allocates on attention's own device and dtype.
        attention_positive_num = torch.ones_like(attention)
        attention_positive_num[attention_bmm < 0] = 0
        att_pos_mask = attention_positive_num.clone()  # 1 where logit >= 0
        attention_positive_num = torch.sum(attention_positive_num, dim=-1, keepdim=True).expand_as(attention_sort)
        attention_sort_pos = attention_sort.float().clone()
        apn = attention_positive_num - 1
        # Ranks past the count of positive logits in the row are reset to 0.
        attention_sort_pos[attention_sort > apn] = 0
        # Positive entries weighted by (rank+1)^3; negative entries keep weight 1.
        attention_mask = ((attention_sort_pos + 1) ** 3) * att_pos_mask + (1 - att_pos_mask)
        out = torch.bmm(attention * attention_mask, x_value)
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)
        return out + x
|
|
|
|
class AttLayer(nn.Module):
    """Cross-image consensus attention layer.

    Computes attention between every spatial position of every image in
    the batch and all positions of all images, derives a per-image seed
    location (the most "agreed-upon" position), and correlates features
    against the mean seed vector to produce a consensus-weighted feature
    map and a class prototype.
    """

    def __init__(self, input_channels=512):
        super(AttLayer, self).__init__()
        self.query_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)
        # Standard scaled dot-product attention factor, 1/sqrt(d).
        self.scale = 1.0 / (input_channels ** 0.5)
        self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=1, stride=1, padding=0)

    def correlation(self, x5, seeds):
        """Correlate `x5` with seed vectors and min-max normalize per image.

        Args:
            x5: (B, C, H5, W5) feature maps (expected L2-normalized over C).
            seeds: (B, C, 1, 1) seed vectors, used as 1x1 conv weights.

        Returns:
            (B, 1, H5, W5) correlation maps scaled to [0, 1] per image.
        """
        B, C, H5, W5 = x5.size()
        # NOTE: eval applies ReLU to the raw correlations, training does not
        # (kept as-is — the asymmetry is deliberate in the original).
        if self.training:
            correlation_maps = F.conv2d(x5, weight=seeds)
        else:
            correlation_maps = torch.relu(F.conv2d(x5, weight=seeds))
        correlation_maps = correlation_maps.mean(1).view(B, -1)
        min_value = torch.min(correlation_maps, dim=1, keepdim=True)[0]
        max_value = torch.max(correlation_maps, dim=1, keepdim=True)[0]
        # Epsilon guards against division by zero on constant maps.
        correlation_maps = (correlation_maps - min_value) / (max_value - min_value + 1e-12)
        correlation_maps = correlation_maps.view(B, 1, H5, W5)
        return correlation_maps

    def forward(self, x5):
        """Args:
            x5: (B, C, H5, W5) deepest backbone features.

        Returns:
            Tuple of (features, prototype (1, C, 1, 1),
            consensus-enhanced features (B, C, H5, W5),
            seed mask (B, 1, H5, W5)).
        """
        x5 = self.conv(x5) + x5  # residual 1x1 refinement
        B, C, H5, W5 = x5.size()
        x_query = self.query_transform(x5).view(B, C, -1)
        x_query = torch.transpose(x_query, 1, 2).contiguous().view(-1, C)  # (B*HW, C)
        x_key = self.key_transform(x5).view(B, C, -1)
        x_key = torch.transpose(x_key, 0, 1).contiguous().view(C, -1)      # (C, B*HW)
        # Every position attends to every position of every image in the batch.
        x_w1 = torch.matmul(x_query, x_key) * self.scale                   # (B*HW, B*HW)
        x_w = x_w1.view(B * H5 * W5, B, H5 * W5)
        x_w = torch.max(x_w, -1).values   # best match within each image
        x_w = x_w.mean(-1)                # consensus across the batch
        x_w = x_w.view(B, -1)
        x_w = F.softmax(x_w, dim=-1)      # (B, HW) consensus distribution

        norm0 = F.normalize(x5, dim=1)    # L2-normalize channels for correlation

        # One-hot mask at each image's highest-consensus position.
        x_w = x_w.unsqueeze(1)
        x_w_max = torch.max(x_w, -1).values.unsqueeze(2).expand_as(x_w)
        # BUGFIX: was `torch.zeros_like(x_w).cuda()`, which forced the mask
        # onto the default GPU and crashed CPU / non-default-device runs.
        # zeros_like already allocates on x_w's own device and dtype.
        mask = torch.zeros_like(x_w)
        mask[x_w == x_w_max] = 1
        mask = mask.view(B, 1, H5, W5)
        # Seed vector = normalized feature at the selected position(s).
        seeds = norm0 * mask
        seeds = seeds.sum(3).sum(2).unsqueeze(2).unsqueeze(3)  # (B, C, 1, 1)
        cormap = self.correlation(norm0, seeds)
        x51 = x5 * cormap
        proto1 = torch.mean(x51, (0, 2, 3), True)  # batch-wide prototype
        return x5, proto1, x5 * proto1 + x51, mask
|
|
|
|
| class Decoder(nn.Module): |
| def __init__(self): |
| super(Decoder, self).__init__() |
| self.toplayer = nn.Sequential( |
| nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)) |
| self.latlayer4 = LatLayer(in_channel=512) |
| self.latlayer3 = LatLayer(in_channel=256) |
| self.latlayer2 = LatLayer(in_channel=128) |
| self.latlayer1 = LatLayer(in_channel=64) |
|
|
| self.enlayer4 = EnLayer() |
| self.enlayer3 = EnLayer() |
| self.enlayer2 = EnLayer() |
| self.enlayer1 = EnLayer() |
|
|
| self.dslayer4 = DSLayer() |
| self.dslayer3 = DSLayer() |
| self.dslayer2 = DSLayer() |
| self.dslayer1 = DSLayer() |
|
|
| def _upsample_add(self, x, y): |
| [_, _, H, W] = y.size() |
| x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) |
| return x + y |
|
|
| def forward(self, weighted_x5, x4, x3, x2, x1, H, W): |
| preds = [] |
| p5 = self.toplayer(weighted_x5) |
| p4 = self._upsample_add(p5, self.latlayer4(x4)) |
| p4 = self.enlayer4(p4) |
| _pred = self.dslayer4(p4) |
| preds.append( |
| F.interpolate(_pred, |
| size=(H, W), |
| mode='bilinear', align_corners=False)) |
|
|
| p3 = self._upsample_add(p4, self.latlayer3(x3)) |
| p3 = self.enlayer3(p3) |
| _pred = self.dslayer3(p3) |
| preds.append( |
| F.interpolate(_pred, |
| size=(H, W), |
| mode='bilinear', align_corners=False)) |
|
|
| p2 = self._upsample_add(p3, self.latlayer2(x2)) |
| p2 = self.enlayer2(p2) |
| _pred = self.dslayer2(p2) |
| preds.append( |
| F.interpolate(_pred, |
| size=(H, W), |
| mode='bilinear', align_corners=False)) |
|
|
| p1 = self._upsample_add(p2, self.latlayer1(x1)) |
| p1 = self.enlayer1(p1) |
| _pred = self.dslayer1(p1) |
| preds.append( |
| F.interpolate(_pred, |
| size=(H, W), |
| mode='bilinear', align_corners=False)) |
| return preds |
|
|
|
|
class DCFMNet(nn.Module):
    """Class for extracting activations and
    registering gradients from targetted intermediate layers.

    Pipeline: VGG feature extraction -> consensus fusion (AttLayer) ->
    attention augmentation (AugAttentionModule) -> FPN decoder.
    """

    def __init__(self, mode='train'):
        super(DCFMNet, self).__init__()
        self.gradients = None
        self.backbone = VGG_Backbone()
        self.mode = mode
        self.aug = AugAttentionModule()
        self.fusion = AttLayer(512)
        self.decoder = Decoder()

    def set_mode(self, mode):
        """Switch forward() between 'train' and no-grad inference."""
        self.mode = mode

    def forward(self, x, gt):
        """Run the network; gradients are disabled unless mode=='train'."""
        if self.mode == 'train':
            return self._forward(x, gt)
        with torch.no_grad():
            return self._forward(x, gt)

    def featextract(self, x):
        """Return VGG stage features ordered deepest (conv5) to shallowest."""
        feats = []
        for stage in (self.backbone.conv1, self.backbone.conv2,
                      self.backbone.conv3, self.backbone.conv4,
                      self.backbone.conv5):
            x = stage(x)
            feats.append(x)
        return feats[4], feats[3], feats[2], feats[1], feats[0]

    def _forward(self, x, gt):
        B, _, H, W = x.size()
        x5, x4, x3, x2, x1 = self.featextract(x)
        feat, proto, weighted_x5, cormap = self.fusion(x5)
        feataug = self.aug(weighted_x5)
        preds = self.decoder(feataug, x4, x3, x2, x1, H, W)
        if not self.training:
            return preds
        # Training: derive positive/negative prototypes from GT-masked
        # deep features for the contrastive objective.
        gt = F.interpolate(gt, size=weighted_x5.size()[2:], mode='bilinear', align_corners=False)
        feat_pos, proto_pos, weighted_x5_pos, cormap_pos = self.fusion(x5 * gt)
        feat_neg, proto_neg, weighted_x5_neg, cormap_neg = self.fusion(x5 * (1 - gt))
        return preds, proto, proto_pos, proto_neg
|
|
|
|
class DCFM(nn.Module):
    """Thin wrapper around DCFMNet that fixes the RNG seed at
    construction and forwards mode switches to the inner network.
    """

    def __init__(self, mode='train'):
        super(DCFM, self).__init__()
        set_seed(123)  # reproducible initialization
        self.dcfmnet = DCFMNet()
        self.mode = mode

    def set_mode(self, mode):
        """Propagate the train/inference mode to the wrapped network."""
        self.mode = mode
        self.dcfmnet.set_mode(self.mode)

    def forward(self, x, gt):
        return self.dcfmnet(x, gt)
|
|
|
|