Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import fvcore.nn.weight_init as weight_init | |
| from config import Config | |
| config = Config() | |
class ResBlk(nn.Module):
    """Two stacked 3x3 convolutions with optional BatchNorm.

    The first convolution maps ``channel_in`` to a fixed 64-channel hidden
    width and is followed by a ReLU; the second maps 64 channels to
    ``channel_out`` and is returned without an activation. BatchNorm layers
    are created and applied only when the module-level ``config.use_bn`` is
    truthy. Note that ``forward`` does not add a skip connection.

    Args:
        channel_in: Number of input channels.
        channel_out: Number of output channels.
    """

    def __init__(self, channel_in=64, channel_out=64):
        super().__init__()
        hidden = 64
        self.conv_in = nn.Conv2d(channel_in, hidden, kernel_size=3, stride=1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(hidden, channel_out, kernel_size=3, stride=1, padding=1)
        if config.use_bn:
            self.bn_in = nn.BatchNorm2d(hidden)
            self.bn_out = nn.BatchNorm2d(channel_out)

    def forward(self, x):
        """Apply conv -> [BN] -> ReLU -> conv -> [BN] to ``x``."""
        feat = self.conv_in(x)
        if config.use_bn:
            feat = self.bn_in(feat)
        feat = self.conv_out(self.relu_in(feat))
        return self.bn_out(feat) if config.use_bn else feat
class DSLayer(nn.Module):
    """Prediction head: two 3x3 conv stages followed by a 1x1 projection.

    Each stage is conv -> [BatchNorm] -> ReLU with a 64-channel width; the
    final 1x1 convolution maps to ``channel_out``. BatchNorm layers exist
    only when the module-level ``config.use_bn`` is truthy. When
    ``activation_out`` is truthy a ReLU is applied to the final output.

    Args:
        channel_in: Number of input channels.
        channel_out: Number of output channels.
        activation_out: Any truthy value enables the trailing ReLU; the
            value itself is only tested for truthiness.
    """

    def __init__(self, channel_in=64, channel_out=1, activation_out='relu'):
        super().__init__()
        self.activation_out = activation_out
        width = 64
        self.conv1 = nn.Conv2d(channel_in, width, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        # The projection is always present; only the trailing ReLU is optional.
        self.pred_conv = nn.Conv2d(width, channel_out, kernel_size=1, stride=1, padding=0)
        if activation_out:
            self.pred_relu = nn.ReLU(inplace=True)
        if config.use_bn:
            self.bn1 = nn.BatchNorm2d(width)
            self.bn2 = nn.BatchNorm2d(width)
            self.pred_bn = nn.BatchNorm2d(channel_out)

    def _stage(self, x, conv, bn_name, activation):
        """Run ``conv``, the optional BatchNorm named ``bn_name``, then ``activation``."""
        x = conv(x)
        if config.use_bn:
            x = getattr(self, bn_name)(x)
        return x if activation is None else activation(x)

    def forward(self, x):
        """Return the prediction map for feature tensor ``x``."""
        x = self._stage(x, self.conv1, 'bn1', self.relu1)
        x = self._stage(x, self.conv2, 'bn2', self.relu2)
        final_activation = self.pred_relu if self.activation_out else None
        return self._stage(x, self.pred_conv, 'pred_bn', final_activation)
class half_DSLayer(nn.Module):
    """Lightweight single-channel prediction head.

    A 3x3 convolution reduces ``channel_in`` to a quarter of its width and
    is followed by a ReLU; a 1x1 convolution then projects to one channel.

    Args:
        channel_in: Number of input channels.
    """

    def __init__(self, channel_in=512):
        super().__init__()
        reduced = int(channel_in // 4)
        encoder = [
            nn.Conv2d(channel_in, reduced, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.enlayer = nn.Sequential(*encoder)
        predictor = [nn.Conv2d(reduced, 1, kernel_size=1, stride=1, padding=0)]
        self.predlayer = nn.Sequential(*predictor)

    def forward(self, x):
        """Return the one-channel map predicted from ``x``."""
        return self.predlayer(self.enlayer(x))
class CoAttLayer(nn.Module):
    """Weights features by a channel prototype pooled from an attention module.

    In evaluation mode the whole batch yields one prototype. In training
    mode the batch is split into two halves; each half is weighted by its
    own prototype (positive output) and by the other half's prototype
    (negative output).

    Args:
        channel_in: Number of feature channels.
    """

    def __init__(self, channel_in=512):
        super().__init__()
        # NOTE(review): eval() instantiates whichever class name is stored in
        # Config().relation_module; this is only safe while that value is
        # trusted, developer-controlled configuration.
        self.all_attention = eval(Config().relation_module + '(channel_in)')
        self.conv_output = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.conv_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_transform = nn.Linear(channel_in, channel_in)
        # These three layers are initialised here but not used by forward().
        for module in (self.conv_output, self.conv_transform, self.fc_transform):
            weight_init.c2_msra_fill(module)

    def _prototype(self, feats):
        """Return the (1, C, 1, 1) mean of the attended ``feats`` over batch and space."""
        attended = self.all_attention(feats)
        return torch.mean(attended, (0, 2, 3), True)

    def forward(self, x5):
        """Return ``(weighted_x5, neg_x5)``; ``neg_x5`` is ``None`` outside training."""
        if not self.training:
            return x5 * self._prototype(x5), None

        half = int(x5.shape[0] / 2)
        first, second = x5[:half], x5[half:]
        proto_first = self._prototype(first)
        proto_second = self._prototype(second)

        # Positive pairs: each half with its own prototype.
        weighted_x5 = torch.cat([first * proto_first, second * proto_second], dim=0)
        # Negative pairs: each half with the other half's prototype.
        neg_x5 = torch.cat([first * proto_second, second * proto_first], dim=0)
        return weighted_x5, neg_x5
class ICE(nn.Module):
    """The Integrity Channel Enhancement (ICE) module.

    Three views of the input are built (element-wise cube, cube plus twice
    the input, and their channel-wise concatenation with the input). They
    produce a value map, a key sequence and a query, from which a
    per-channel gate is derived and applied to the value map residually.

    Args:
        channel_in: Number of feature channels.
    """

    def __init__(self, channel_in=512):
        super().__init__()
        self.conv_1 = nn.Conv2d(channel_in, channel_in, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv1d(channel_in, channel_in, kernel_size=3, stride=1, padding=1)
        self.conv_3 = nn.Conv2d(channel_in * 3, channel_in, kernel_size=3, stride=1, padding=1)
        self.fc_2 = nn.Linear(channel_in, channel_in)
        self.fc_3 = nn.Linear(channel_in, channel_in)

    def forward(self, x):
        """Return the channel-enhanced features, same shape as ``x`` (B, C, H, W)."""
        batch, channels, height, width = x.shape

        cubed = x * x * x
        summed = cubed + x + x
        stacked = torch.cat((cubed, summed, x), dim=1)

        value = self.conv_1(cubed)
        key = self.conv_2(summed.view(batch, channels, height * width))  # B, C, HW

        # Query weights: per-channel spatial norm -> linear -> softmax over channels.
        query_weights = torch.norm(self.conv_3(stacked), dim=(-2, -1)).view(batch, channels)
        query_weights = torch.softmax(self.fc_3(query_weights), dim=-1).unsqueeze(1)  # B, 1, C
        query = torch.matmul(query_weights, key)  # B, 1, HW

        # Gate each channel by how closely its key row matches the query.
        gate = F.cosine_similarity(key, query, dim=-1)  # B, C
        gate = self.fc_2(torch.sigmoid(gate))
        gate = gate.unsqueeze(-1).unsqueeze(-1)  # B, C, 1, 1
        return value * gate + value
class GAM(nn.Module):
    """Spatial attention computed from query/key affinities across the batch.

    Every position of every image is compared with every position of every
    image in the batch. For each position the best match per image is
    averaged over images, turned into a spatial softmax, and used to
    re-weight the input before a final 1x1 convolution.

    Args:
        channel_in: Number of feature channels.
    """

    def __init__(self, channel_in=512):
        super().__init__()

        def pointwise():
            return nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)

        self.query_transform = pointwise()
        self.key_transform = pointwise()
        self.scale = 1.0 / (channel_in ** 0.5)
        self.conv6 = pointwise()
        for module in (self.query_transform, self.key_transform, self.conv6):
            weight_init.c2_msra_fill(module)

    def forward(self, x5):
        """Return the attention-weighted features, same shape as ``x5`` (B, C, H, W)."""
        batch, channels, height, width = x5.size()
        positions = height * width

        # Queries as rows: (B, C, HW) -> (B*HW, C).
        queries = self.query_transform(x5).view(batch, channels, -1)
        queries = queries.permute(0, 2, 1).contiguous().view(-1, channels)

        # Keys as columns: (B, C, HW) -> (C, B*HW).
        keys = self.key_transform(x5).view(batch, channels, -1)
        keys = keys.permute(1, 0, 2).contiguous().view(channels, -1)

        # Affinity of every position with every position of every image.
        affinity = torch.matmul(queries, keys).view(batch * positions, batch, positions)
        best_per_image = affinity.max(dim=-1).values  # B*HW, B
        score = best_per_image.mean(dim=-1)  # B*HW

        # The scale is applied after pooling, just before the spatial softmax.
        attention = F.softmax(score.view(batch, -1) * self.scale, dim=-1)  # B, HW
        attention = attention.view(batch, height, width).unsqueeze(1)  # B, 1, H, W

        return self.conv6(x5 * attention)
class MHA(nn.Module):
    """Multi-head scaled dot-product self-attention over a feature map.

    The input is projected by three separate 1x1 convolutions into queries,
    keys and values, reshaped into token sequences, and passed through
    multi-head attention.

    Args:
        d_model: Feature size consumed and produced by the linear layers;
            must equal the channel count of the input to ``forward``.
        d_k: Dimensionality of queries and keys per head.
        d_v: Dimensionality of values per head.
        h: Number of heads.
        dropout: Dropout probability applied to the attention weights.
        channel_in: Channel count of the 1x1 query/key/value convolutions.
    """

    def __init__(self, d_model=512, d_k=512, d_v=512, h=8, dropout=.1, channel_in=512):
        super(MHA, self).__init__()
        self.query_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.key_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.value_transform = nn.Conv2d(channel_in, channel_in, kernel_size=1, stride=1, padding=0)
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h
        self.init_weights()

    def init_weights(self):
        """Initialise conv, batch-norm and linear sub-modules in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, attention_mask=None, attention_weights=None):
        """Apply self-attention to a feature map.

        Keys and values are computed with ``key_transform`` and
        ``value_transform``. Earlier revisions reused ``query_transform``
        for all three projections, leaving those two layers unused;
        checkpoints trained with that behaviour will produce different
        outputs here and may need fine-tuning.

        Args:
            x: Feature map of shape (B, C, H, W) with ``C == d_model``.
            attention_mask: Optional boolean mask broadcastable to
                (B, h, nq, nk); True entries are excluded from attention.
            attention_weights: Optional multiplicative weights
                broadcastable to (B, h, nq, nk), applied before the mask.

        Returns:
            Tensor of shape (B, C, H, W).
        """
        B, C, H, W = x.size()
        # NOTE(review): .view(B, -1, C) reinterprets the (B, C, H, W) buffer
        # without transposing, so a "token" is not one spatial position's
        # channel vector. Left unchanged; confirm whether this is intended.
        queries = self.query_transform(x).view(B, -1, C)
        keys = self.key_transform(x).view(B, -1, C)
        values = self.value_transform(x).view(B, -1, C)

        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out).view(B, C, H, W)  # (b_s, nq, d_model) -> (B, C, H, W)
        return out
class NonLocal(nn.Module):
    """Non-local (self-attention) block with a residual connection.

    The layer types follow ``dimension`` so the block accepts 1-D
    (B, C, L), 2-D (B, C, H, W) or 3-D (B, C, T, H, W) inputs. With the
    default ``dimension=2`` the constructed layers are the same 2-D layers
    as before.

    Args:
        channel_in: Number of input (and output) channels.
        inter_channels: Width of the embedding; defaults to
            ``channel_in // 2`` (at least 1).
        dimension: Number of spatial dimensions of the input (1, 2 or 3).
        sub_sample: If True, max-pool the ``g`` and ``phi`` branches.
        bn_layer: If True, follow the output projection with BatchNorm.
    """

    # (conv, batch-norm, max-pool, pool kernel) for each supported dimension.
    _ND_LAYERS = {
        1: (nn.Conv1d, nn.BatchNorm1d, nn.MaxPool1d, 2),
        2: (nn.Conv2d, nn.BatchNorm2d, nn.MaxPool2d, (2, 2)),
        3: (nn.Conv3d, nn.BatchNorm3d, nn.MaxPool3d, (1, 2, 2)),
    }

    def __init__(self, channel_in=512, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True):
        super(NonLocal, self).__init__()
        assert dimension in [1, 2, 3]
        conv_nd, bn_nd, pool_nd, pool_kernel = self._ND_LAYERS[dimension]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.channel_in = channel_in
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = channel_in // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g = conv_nd(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)

        # The output projection is zero-initialised (BatchNorm affine params
        # when present, otherwise the conv itself), so the block starts out
        # as an identity mapping through the residual connection.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0),
                bn_nd(self.channel_in)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(self.inter_channels, self.channel_in, kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(self.channel_in, self.inter_channels, kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, pool_nd(kernel_size=pool_kernel))
            self.phi = nn.Sequential(self.phi, pool_nd(kernel_size=pool_kernel))

    def forward(self, x, return_nl_map=False):
        """Apply the non-local operation.

        Args:
            x: Tensor of shape (B, C, *spatial) whose number of spatial
                dimensions matches ``dimension``.
            return_nl_map: If True, also return the attention map.

        Returns:
            ``z`` with the same shape as ``x``, or ``(z, attention)`` when
            ``return_nl_map`` is True.
        """
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        W_y = self.W(y)
        z = W_y + x

        if return_nl_map:
            return z, f_div_C
        return z
class DBHead(nn.Module):
    """Binarization head built from two parallel branches.

    ``binarize`` and ``thresh`` are identical stacks (two 3x3 conv blocks,
    a 1x1 projection and a sigmoid). Their outputs are combined by
    ``step_function`` into a soft binary map.

    Args:
        channel_in: Number of input channels (also the hidden width).
        channel_out: Number of output channels.
        k: Steepness factor used by ``step_function``.
    """

    def __init__(self, channel_in=32, channel_out=1, k=config.db_k):
        super().__init__()
        self.k = k
        self.binarize = self._make_branch(channel_in, channel_out)
        self.thresh = self._make_branch(channel_in, channel_out)

    @staticmethod
    def _make_branch(channel_in, channel_out):
        """Build one conv -> [BN] -> ReLU (x2) -> 1x1 conv -> sigmoid branch."""

        def conv_block():
            # Always return a list so it can be unpacked whether or not
            # BatchNorm is enabled. The previous
            # ``*[bn, relu] if use_bn else relu`` tried to unpack a bare
            # ReLU module when use_bn was false and raised TypeError.
            layers = [nn.Conv2d(channel_in, channel_in, 3, 1, 1)]
            if config.use_bn:
                layers.append(nn.BatchNorm2d(channel_in))
            layers.append(nn.ReLU(inplace=True))
            return layers

        return nn.Sequential(
            *conv_block(),
            *conv_block(),
            nn.Conv2d(channel_in, channel_out, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the soft binary map for feature tensor ``x``."""
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        return binary_maps

    def step_function(self, x, y):
        """Return ``1 / (1 + exp(-k * g(x - y)))``.

        ``g`` is the identity when ``config.db_k_alpha == 1``; otherwise it
        is a sign-preserving power ``sign(z) * (|z| + 1e-16) ** (1 / alpha)``.
        If the exponential overflows to infinity it is recomputed with a
        fixed steepness of 50.
        """
        diff = x - y
        if config.db_k_alpha != 1:
            # +1 where diff >= 0, -1 where diff < 0.
            sign = 1 - 2 * (diff < 0)
            # The exponent reads the same attribute as the condition above;
            # the earlier ``config.k_alpha`` did not match it.
            # NOTE(review): confirm Config defines no separate ``k_alpha``.
            a = torch.exp(-self.k * (torch.pow(diff * sign + 1e-16, 1 / config.db_k_alpha) * sign))
        else:
            a = torch.exp(-self.k * diff)
        if torch.isinf(a).any():
            a = torch.exp(-50 * diff)
        return torch.reciprocal(1 + a)
class RefUnet(nn.Module):
    """U-Net style refinement module.

    Four conv/pool encoder stages, a bottleneck and four decoder stages
    with skip connections. The output is either a residual correction
    added to the input or, when ``config.db_output_refiner`` is truthy,
    the output of a ``DBHead`` applied to the last decoder features.
    BatchNorm layers exist only when ``config.use_bn`` is truthy.

    Args:
        in_ch: Number of channels of the input map.
        inc_ch: Number of channels produced by the first convolution.
    """

    def __init__(self, in_ch, inc_ch):
        super(RefUnet, self).__init__()
        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)

        # Encoder: conv -> [BN] -> ReLU, each followed by a 2x2 max-pool.
        self._add_stage('1', inc_ch)
        self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('2', 64)
        self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('3', 64)
        self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
        self._add_stage('4', 64)
        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # Bottleneck.
        self._add_stage('5', 64)

        # Decoder: each stage consumes upsampled features concatenated
        # with the matching encoder features (64 + 64 channels).
        self._add_stage('_d4', 128)
        self._add_stage('_d3', 128)
        self._add_stage('_d2', 128)
        self._add_stage('_d1', 128)

        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)
        # Retained so the attribute stays available; forward() upsamples to
        # the skip tensor's size instead (see _upsample_like).
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        if config.db_output_refiner:
            self.db_output_refiner = DBHead(64)

    def _add_stage(self, suffix, channels_in):
        """Register ``conv<suffix>``, optional ``bn<suffix>`` and ``relu<suffix>``."""
        setattr(self, 'conv' + suffix, nn.Conv2d(channels_in, 64, 3, padding=1))
        if config.use_bn:
            setattr(self, 'bn' + suffix, nn.BatchNorm2d(64))
        setattr(self, 'relu' + suffix, nn.ReLU(inplace=True))

    def _run_stage(self, x, suffix):
        """Apply conv -> [BN] -> ReLU for the stage named by ``suffix``."""
        x = getattr(self, 'conv' + suffix)(x)
        if config.use_bn:
            x = getattr(self, 'bn' + suffix)(x)
        return getattr(self, 'relu' + suffix)(x)

    @staticmethod
    def _upsample_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size of ``ref``.

        The pools use ``ceil_mode=True``, so an odd-sized feature map is
        not exactly half of a doubled one; resizing to the skip tensor's
        size keeps the concatenation valid for any input size. When the
        size is an exact doubling this matches the fixed x2 upsampling.
        """
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def _decode(self, deep, skip, suffix):
        """Upsample ``deep``, concatenate with ``skip`` and run one decoder stage."""
        merged = torch.cat((self._upsample_like(deep, skip), skip), 1)
        return self._run_stage(merged, suffix)

    def forward(self, x):
        """Return the refined map for input ``x`` of shape (B, in_ch, H, W)."""
        hx1 = self._run_stage(self.conv0(x), '1')
        hx2 = self._run_stage(self.pool1(hx1), '2')
        hx3 = self._run_stage(self.pool2(hx2), '3')
        hx4 = self._run_stage(self.pool3(hx3), '4')
        hx5 = self._run_stage(self.pool4(hx4), '5')

        d4 = self._decode(hx5, hx4, '_d4')
        d3 = self._decode(d4, hx3, '_d3')
        d2 = self._decode(d3, hx2, '_d2')
        d1 = self._decode(d2, hx1, '_d1')

        if config.db_output_refiner:
            x = self.db_output_refiner(d1)
        else:
            residual = self.conv_d0(d1)
            x = x + residual
        return x