# This code is cloned from https://github.com/ooooverflow/BiSeNet
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import model_util

warnings.filterwarnings(action='ignore')


def flatten(tensor):
    """Flatten a tensor so that the channel axis comes first.

    The shapes are transformed as follows:
        (N, C, D, H, W) -> (C, N * D * H * W)
    """
    num_channels = tensor.size(1)
    # Swap the first two axes and keep the remaining ones in place:
    # (N, C, D, H, W) -> (C, N, D, H, W)
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Collapse everything but the channel axis: (C, N, ...) -> (C, N * ...)
    return tensor.permute(axis_order).contiguous().view(num_channels, -1)


class DiceLoss(nn.Module):
    """Dice-style overlap loss averaged over channels.

    ``forward`` applies a softmax over dim 1 of ``output`` before the
    overlap is measured, so ``output`` is expected to hold raw scores.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): kept for compatibility; forward() does not use it.
        self.epsilon = 1e-5

    def forward(self, output, target):
        """Return ``1 - mean_c(sum(p * t) / sum(p + t))``.

        ``output`` and ``target`` must have identical shapes.
        """
        assert output.size() == target.size(), "'input' and 'target' must have the same shape"
        probs = flatten(F.softmax(output, dim=1))
        target = flatten(target)
        intersect = (probs * target).sum(-1)
        denominator = (probs + target).sum(-1)
        # NOTE(review): the conventional Dice coefficient carries a factor
        # of 2 in the numerator. It is deliberately left out here, as in
        # the original code, so the scale of the loss does not change.
        dice = torch.mean(intersect / denominator)
        return 1 - dice


class _ResNetContextPath(torch.nn.Module):
    """Shared wrapper exposing the stages of a ``model_util`` ResNet.

    The backbone is stored as ``features`` and its stages are additionally
    registered under their own names, exactly as the original per-backbone
    classes did, so state-dict keys are unchanged.
    """

    def __init__(self, backbone):
        super().__init__()
        self.features = backbone
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for the given batch.

        ``tail`` is ``feature4`` averaged over its last two axes with the
        reduced axes kept as size 1.
        """
        x = self.conv1(input)
        x = self.relu(self.bn1(x))
        x = self.maxpool1(x)
        feature1 = self.layer1(x)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail
        tail = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(tail, 2, keepdim=True)
        return feature3, feature4, tail


class resnet18(_ResNetContextPath):
    """Context path backed by ``model_util.resnet18``."""

    def __init__(self, pretrained=True):
        super().__init__(model_util.resnet18(pretrained=pretrained))


class resnet101(_ResNetContextPath):
    """Context path backed by ``model_util.resnet101``."""

    def __init__(self, pretrained=True):
        super().__init__(model_util.resnet101(pretrained=pretrained))


# Name -> context-path class. Only the requested backbone is instantiated.
_CONTEXT_PATHS = {
    'resnet18': resnet18,
    'resnet101': resnet101,
}

# Channel counts of (feature3, feature4) for each supported backbone; these
# are the sizes the original per-backbone branches hard-coded.
_CONTEXT_PATH_CHANNELS = {
    'resnet18': (256, 512),
    'resnet101': (1024, 2048),
}

# Output channels of Spatial_path (its last ConvBlock has out_channels=256).
_SPATIAL_PATH_CHANNELS = 256


def build_contextpath(name, pretrained):
    """Build the context-path network called ``name``.

    Only the requested backbone is constructed; previously every supported
    backbone was instantiated just to select one of them.

    Raises:
        KeyError: if ``name`` is not a supported backbone.
    """
    try:
        factory = _CONTEXT_PATHS[name]
    except KeyError:
        raise KeyError(
            'unsupported context_path network: {!r}'.format(name)
        ) from None
    return factory(pretrained=pretrained)


class ConvBlock(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv1(input)
        return self.relu(self.bn(x))


class Spatial_path(torch.nn.Module):
    """Three stacked ConvBlocks mapping 3 -> 64 -> 128 -> 256 channels.

    Each block uses the ConvBlock default stride of 2.
    """

    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        x = self.convblock1(input)
        x = self.convblock2(x)
        return self.convblock3(x)


class AttentionRefinementModule(torch.nn.Module):
    """Re-weight a feature map by a per-channel sigmoid gate.

    The gate is produced from the globally average-pooled input by a 1x1
    convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Registered but not applied in forward(); kept so existing
        # checkpoints keep loading with the same keys.
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # global average pooling
        pooled = self.avgpool(input)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        gate = self.sigmoid(self.conv(pooled))
        # channels of input and gate should be the same
        return torch.mul(input, gate)


class FeatureFusionModule(torch.nn.Module):
    """Fuse two feature maps by concatenation plus a gated residual."""

    def __init__(self, num_classes, in_channels):
        super().__init__()
        # in_channels = channels(input_1) + channels(input_2). As BiSeNet
        # uses this module:
        #   resnet101: 3328 = 256 (spatial path) + 1024 + 2048 (context path)
        #   resnet18:  1024 = 256 (spatial path) + 256 + 512 (context path)
        self.in_channels = in_channels
        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        x = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == x.size(1), 'in_channels of ConvBlock should be {}'.format(x.size(1))
        feature = self.convblock(x)
        # Channel gate computed from the globally pooled fused feature.
        gate = self.avgpool(feature)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        # feature * gate + feature
        return torch.add(torch.mul(feature, gate), feature)


class BiSeNet(torch.nn.Module):
    """BiSeNet: spatial path + context path fused into per-class maps.

    Args:
        num_classes: number of output channels.
        context_path: backbone name, ``'resnet18'`` or ``'resnet101'``.
        train_flag: when truthy the backbone is built with
            ``pretrained=True``, otherwise with ``pretrained=False``.

    Raises:
        KeyError: if ``context_path`` is not a supported backbone.
    """

    def __init__(self, num_classes, context_path, train_flag=True):
        super().__init__()
        # build spatial path
        # NOTE: the attribute name (sic, "saptial") is part of the
        # state-dict keys and must not be renamed.
        self.saptial_path = Spatial_path()
        self.sigmoid = nn.Sigmoid()

        # build context path (raises KeyError for unknown names)
        self.context_path = build_contextpath(name=context_path, pretrained=bool(train_flag))
        mid_channels, deep_channels = _CONTEXT_PATH_CHANNELS[context_path]

        # build attention refinement modules
        self.attention_refinement_module1 = AttentionRefinementModule(mid_channels, mid_channels)
        self.attention_refinement_module2 = AttentionRefinementModule(deep_channels, deep_channels)
        # supervision block
        self.supervision1 = nn.Conv2d(in_channels=mid_channels, out_channels=num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(in_channels=deep_channels, out_channels=num_classes, kernel_size=1)
        # build feature fusion module
        self.feature_fusion_module = FeatureFusionModule(
            num_classes, _SPATIAL_PATH_CHANNELS + mid_channels + deep_channels
        )

        # build final convolution
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the context path, collected in one list.
        self.mul_lr = [
            self.saptial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Initialise every Conv2d / BatchNorm2d outside the context path."""
        for name, m in self.named_modules():
            if 'context_path' in name:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-5
                m.momentum = 0.1
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Run the network.

        Returns:
            In training mode a tuple ``(result, cx1_sup, cx2_sup)``; in
            eval mode only ``result``. Every returned tensor has passed
            through a sigmoid and is resized to the last two dimensions
            of ``input``.
        """
        input_size = input.size()[-2:]

        # output of spatial path
        sx = self.saptial_path(input)
        spatial_size = sx.size()[-2:]

        # output of context path
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        # upsampling to the spatial-path resolution; align_corners=False is
        # what bilinear mode uses when the argument is omitted
        cx1 = F.interpolate(cx1, size=spatial_size, mode='bilinear', align_corners=False)
        cx2 = F.interpolate(cx2, size=spatial_size, mode='bilinear', align_corners=False)
        cx = torch.cat((cx1, cx2), dim=1)

        if self.training:
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = F.interpolate(cx1_sup, size=input_size, mode='bilinear', align_corners=False)
            cx2_sup = F.interpolate(cx2_sup, size=input_size, mode='bilinear', align_corners=False)

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # Upsample to the input resolution. A fixed scale_factor=8 would
        # overshoot when H or W is not a multiple of 8 (each stride-2
        # ConvBlock rounds the size up) and would then disagree with the
        # auxiliary outputs above, which are resized to input_size.
        result = F.interpolate(result, size=input_size, mode='bilinear', align_corners=False)
        result = self.conv(result)

        if self.training:
            return self.sigmoid(result), self.sigmoid(cx1_sup), self.sigmoid(cx2_sup)

        return self.sigmoid(result)