Spaces:
Running
Running
| # This code is cloned from https://github.com/ooooverflow/BiSeNet | |
| import torch.nn as nn | |
| import torch | |
| import torch.nn.functional as F | |
| from . import model_util | |
| import warnings | |
| warnings.filterwarnings(action='ignore') | |
def flatten(tensor):
    """Collapse every axis except the channel axis into a single axis.

    The channel axis (dim 1) is moved to the front and all remaining axes
    are merged, e.g.:
        (N, C, D, H, W) -> (C, N * D * H * W)
    """
    num_channels = tensor.shape[1]
    # Swap the first two axes; any trailing axes keep their relative order.
    channel_first = tensor.permute(1, 0, *range(2, tensor.dim()))
    # Make the permuted tensor contiguous so that view() can merge the axes.
    return channel_first.contiguous().view(num_channels, -1)
class DiceLoss(nn.Module):
    """Dice-style overlap loss computed on channel-wise softmax probabilities.

    ``forward`` returns ``1 - mean_c(sum(p_c * t_c) / sum(p_c + t_c))`` where
    ``p`` is the softmax of the logits over dim 1 and ``c`` indexes channels.
    The ratio is not multiplied by 2, and ``epsilon`` is stored on the module
    but is not used in the computation.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-5

    def forward(self, output, target):
        """Return the scalar loss for logits ``output`` and same-shaped ``target``."""
        assert output.size() == target.size(), "'input' and 'target' must have the same shape"
        # Softmax over the channel axis, then one row per channel.
        probs = flatten(F.softmax(output, dim=1))
        truth = flatten(target)
        # Per-channel overlap and per-channel total mass.
        overlap = (probs * truth).sum(-1)
        total = (probs + truth).sum(-1)
        per_channel = overlap / total
        return 1 - torch.mean(per_channel)
class resnet18(torch.nn.Module):
    """Context-path backbone built on ``model_util.resnet18``.

    The wrapped network is stored as ``features``; its stem layers and four
    stages are additionally exposed as direct attributes, which is what
    ``forward`` uses.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = model_util.resnet18(pretrained=pretrained)
        self.features = backbone
        # Stem
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        # Residual stages
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for ``input``.

        ``tail`` is ``feature4`` averaged over its last two axes (kept as
        size-1 axes).
        """
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        # Scale annotations below are carried over from the original comments.
        feature1 = self.layer1(stem)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: average over dim 3 first, then dim 2.
        tail = feature4.mean(3, keepdim=True).mean(2, keepdim=True)
        return feature3, feature4, tail
class resnet101(torch.nn.Module):
    """Context-path backbone built on ``model_util.resnet101``.

    The wrapped network is stored as ``features``; its stem layers and four
    stages are additionally exposed as direct attributes, which is what
    ``forward`` uses.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        net = model_util.resnet101(pretrained=pretrained)
        self.features = net
        # Stem
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.relu = net.relu
        self.maxpool1 = net.maxpool
        # Residual stages
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for ``input``.

        ``tail`` is ``feature4`` averaged over its last two axes (kept as
        size-1 axes).
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        # Scale annotations below are carried over from the original comments.
        feature2 = self.layer2(self.layer1(x))  # 1 / 4, then 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: average over dim 3 first, then dim 2.
        pooled_last_axis = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(pooled_last_axis, 2, keepdim=True)
        return feature3, feature4, tail
def build_contextpath(name, pretrained):
    """Build the context-path backbone selected by ``name``.

    Args:
        name: ``'resnet18'`` or ``'resnet101'``.
        pretrained: forwarded as the ``pretrained`` argument of the selected
            backbone wrapper.

    Returns:
        A new instance of the selected backbone wrapper.

    Raises:
        KeyError: if ``name`` is not one of the supported backbones.
    """
    # Map names to classes rather than instances so that only the requested
    # backbone is constructed; the other one is never built or loaded.
    backbones = {
        'resnet18': resnet18,
        'resnet101': resnet101,
    }
    return backbones[name](pretrained=pretrained)
class ConvBlock(torch.nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Apply conv -> batch norm -> ReLU to ``input``."""
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)
class Spatial_path(torch.nn.Module):
    """Three stacked ``ConvBlock`` stages mapping 3 -> 64 -> 128 -> 256 channels.

    Every stage is created with ``ConvBlock``'s default kernel size, stride
    and padding.
    """

    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        """Run ``input`` through the three stages in order."""
        x = input
        for stage in (self.convblock1, self.convblock2, self.convblock3):
            x = stage(x)
        return x
class AttentionRefinementModule(torch.nn.Module):
    """Re-weight the channels of a feature map with a globally pooled gate.

    The input is average-pooled to 1x1, passed through a 1x1 convolution and
    a sigmoid, and the result multiplies the input. ``bn`` is created (so it
    is part of the module's state) but ``forward`` does not apply it.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        """Return ``input`` scaled by its per-channel attention gate."""
        # Global average pooling down to a 1x1 map per channel.
        pooled = self.avgpool(input)
        channels = pooled.size(1)
        assert self.in_channels == channels, 'in_channels and out_channels should all be {}'.format(channels)
        gate = self.sigmoid(self.conv(pooled))
        # The gate must have as many channels as the input for this product.
        return torch.mul(input, gate)
class FeatureFusionModule(torch.nn.Module):
    """Fuse two feature maps by concatenation, projection and channel gating.

    ``in_channels`` must equal the channel count of the concatenated inputs.
    As used by ``BiSeNet`` in this file:
        resnet101: 3328 = 256 (spatial path) + 1024 + 2048 (context path)
        resnet18:  1024 = 256 (spatial path) + 256 + 512 (context path)
    """

    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Stride 1 keeps the spatial size while projecting to num_classes channels.
        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        """Return the fused feature map for ``input_1`` and ``input_2``."""
        merged = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == merged.size(1), 'in_channels of ConvBlock should be {}'.format(merged.size(1))
        feature = self.convblock(merged)
        # Channel gate: pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid.
        squeezed = self.avgpool(feature)
        hidden = self.relu(self.conv1(squeezed))
        gate = self.sigmoid(self.conv2(hidden))
        # Gated feature plus the ungated feature.
        return torch.add(torch.mul(feature, gate), feature)
class BiSeNet(torch.nn.Module):
    """Two-branch segmentation network: a spatial path plus a ResNet context path.

    Args:
        num_classes: number of output channels of every prediction head.
        context_path: backbone name, ``'resnet18'`` or ``'resnet101'``.
        train_flag: when truthy the backbone is built with ``pretrained=True``,
            otherwise with ``pretrained=False``.

    Raises:
        ValueError: if ``context_path`` names a backbone this class has no
            channel configuration for.

    ``forward`` returns the sigmoid of the main prediction; in training mode
    it returns a 3-tuple that also contains the sigmoid of the two auxiliary
    (supervision) predictions.
    """

    def __init__(self, num_classes, context_path, train_flag=True):
        super().__init__()
        # build spatial path
        # NOTE: the attribute keeps its original spelling ("saptial") because
        # the name is part of the module's parameter/state-dict keys.
        self.saptial_path = Spatial_path()
        self.sigmoid = nn.Sigmoid()

        # build context path
        self.context_path = build_contextpath(name=context_path, pretrained=bool(train_flag))

        # Channel counts of the two context features (feature3, feature4).
        context_channels = {
            'resnet101': (1024, 2048),
            'resnet18': (256, 512),
        }
        if context_path not in context_channels:
            # Fail here with a clear message instead of printing and then
            # hitting an AttributeError on the missing sub-modules below.
            raise ValueError('Error: unsupported context_path network: {!r}'.format(context_path))
        channels3, channels4 = context_channels[context_path]

        # build attention refinement modules
        self.attention_refinement_module1 = AttentionRefinementModule(channels3, channels3)
        self.attention_refinement_module2 = AttentionRefinementModule(channels4, channels4)
        # supervision block
        self.supervision1 = nn.Conv2d(in_channels=channels3, out_channels=num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(in_channels=channels4, out_channels=num_classes, kernel_size=1)
        # build feature fusion module: 256 spatial-path channels plus both
        # context features (3328 for resnet101, 1024 for resnet18).
        self.feature_fusion_module = FeatureFusionModule(num_classes, 256 + channels3 + channels4)

        # build final convolution
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the backbone, collected in one list.
        # NOTE(review): presumably consumed by the training script to apply a
        # separate learning rate to these modules — confirm against the caller.
        self.mul_lr = [
            self.saptial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Re-initialise all conv / batch-norm layers outside the context path."""
        for name, m in self.named_modules():
            # The backbone keeps whatever weights it was built with.
            if 'context_path' in name:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-5
                m.momentum = 0.1
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Run the network on ``input``.

        Returns:
            In eval mode, ``sigmoid(result)``. In training mode, the tuple
            ``(sigmoid(result), sigmoid(cx1_sup), sigmoid(cx2_sup))``.
        """
        # output of spatial path
        sx = self.saptial_path(input)
        spatial_size = sx.size()[-2:]

        # output of context path
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)

        # upsample both context features to the spatial-path resolution;
        # align_corners=False is the default behaviour, spelled out explicitly.
        cx1 = F.interpolate(cx1, size=spatial_size, mode='bilinear', align_corners=False)
        cx2 = F.interpolate(cx2, size=spatial_size, mode='bilinear', align_corners=False)
        cx = torch.cat((cx1, cx2), dim=1)

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # upsampling
        # NOTE(review): the main output is scaled by a fixed factor of 8 while
        # the auxiliary outputs below are resized to the input size; these
        # presumably only match when the input height and width are multiples
        # of 8 — confirm against the data pipeline.
        result = F.interpolate(result, scale_factor=8, mode='bilinear', align_corners=False)
        result = self.conv(result)

        if not self.training:
            return self.sigmoid(result)

        # Auxiliary supervision heads are only evaluated in training mode.
        input_size = input.size()[-2:]
        cx1_sup = F.interpolate(self.supervision1(cx1), size=input_size, mode='bilinear', align_corners=False)
        cx2_sup = F.interpolate(self.supervision2(cx2), size=input_size, mode='bilinear', align_corners=False)
        return self.sigmoid(result), self.sigmoid(cx1_sup), self.sigmoid(cx2_sup)