# Full DANet architecture (same as before): dual-attention head on a dilated ResNet-50 backbone.
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import models


class PositionAttentionModule(nn.Module):
    """Position (spatial) attention: every pixel attends to every other pixel."""

    def __init__(self, in_dim):
        super(PositionAttentionModule, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x HW x C//8
        proj_key = self.key_conv(x).view(B, -1, H * W)                       # B x C//8 x HW
        energy = torch.bmm(proj_query, proj_key)                             # B x HW x HW
        attention_map = self.softmax(energy)
        proj_value = self.value_conv(x).view(B, -1, H * W)                   # B x C x HW
        out = torch.bmm(proj_value, attention_map.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        out = self.gamma * out + x
        return out


class ChannelAttentionModule(nn.Module):
    """Channel attention: every channel attends to every other channel."""

    def __init__(self):
        super(ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))  # learnable residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        proj_query = x.view(B, C, -1)                 # B x C x HW
        proj_key = x.view(B, C, -1).permute(0, 2, 1)  # B x HW x C
        energy = torch.bmm(proj_query, proj_key)      # B x C x C
        attention_map = self.softmax(energy)
        proj_value = x.view(B, C, -1)
        out = torch.bmm(attention_map, proj_value)
        out = out.view(B, C, H, W)
        out = self.beta * out + x
        return out


class DANetHead(nn.Module):
    """Runs the two attention branches in parallel and fuses them by element-wise sum."""

    def __init__(self, in_channels, out_channels):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.conv5c = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.sa = PositionAttentionModule(inter_channels)
        self.sc = ChannelAttentionModule()
        self.conv51 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.conv52 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels), nn.ReLU())
        self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False),
                                   nn.Conv2d(inter_channels, out_channels, 1))

    def forward(self, x):
        # Position-attention branch
        feat_sa = self.conv5a(x)
        sa_feat = self.sa(feat_sa)
        sa_conv = self.conv51(sa_feat)
        # Channel-attention branch
        feat_sc = self.conv5c(x)
        sc_feat = self.sc(feat_sc)
        sc_conv = self.conv52(sc_feat)
        # Sum fusion, then the final 1x1 classifier
        feat_sum = sa_conv + sc_conv
        s_out = self.conv8(feat_sum)
        return s_out


class DANet(nn.Module):
    """DANet with a ResNet-50 backbone; the strides in layer3/layer4 are removed and
    layer4 is dilated so the head sees higher-resolution features."""

    def __init__(self, num_classes=2, backbone='resnet50', pretrained_base=False, aux=False):
        super(DANet, self).__init__()
        self.aux = aux
        # Only 'resnet50' is wired up here; `backbone` is kept for API compatibility.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained_base else None
        resnet = models.resnet50(weights=weights)
        backbone_out_channels = 2048
        # Remove the downsampling strides in layer3/layer4 and dilate layer4's 3x3 convs.
        resnet.layer3[0].conv2.stride = (1, 1)
        resnet.layer3[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        for i in range(len(resnet.layer4)):
            resnet.layer4[i].conv2.dilation = (2, 2)
            resnet.layer4[i].conv2.padding = (2, 2)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.head = DANetHead(backbone_out_channels, num_classes)

    def forward(self, x):
        imsize = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        c3 = self.layer3(x)
        c4 = self.layer4(c3)
        main_out = self.head(c4)
        # Upsample the logits back to the input resolution.
        main_out = F.interpolate(main_out, size=imsize, mode='bilinear', align_corners=True)
        return main_out
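

# A minimal sanity-check sketch, assuming a 2-class task and a 256x256 RGB input
# (both are illustrative choices, not requirements of the model): running this file
# directly should confirm the network builds and that the logits come back at the
# input resolution.
if __name__ == "__main__":
    model = DANet(num_classes=2, pretrained_base=False)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)  # B x C x H x W
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 2, 256, 256])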