# Dojofdd's picture
# Add config.json and model card to enable Inference API
# cca733a verified
# Paste the full DANet architecture code here... (same as before)
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import models
class PositionAttentionModule(nn.Module):
    """Spatial (position) self-attention over an NCHW feature map.

    Every spatial location attends to every other location. The result is
    blended residually as ``gamma * attended + x``; since ``gamma`` starts at
    zero, the module is an exact identity mapping at initialisation.
    """

    def __init__(self, in_dim):
        super(PositionAttentionModule, self).__init__()
        self.chanel_in = in_dim  # NOTE(review): original (sic) name kept so checkpoints still load
        # Query/key are projected down to C // 8 channels to keep the N x N
        # affinity matmul cheap; value keeps the full channel count.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learned residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attention(x) + x``.
        """
        batch, channels, height, width = x.size()
        n = height * width
        queries = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)  # (B, N, C//8)
        keys = self.key_conv(x).view(batch, -1, n)                        # (B, C//8, N)
        attention = self.softmax(torch.bmm(queries, keys))                # (B, N, N), rows sum to 1
        values = self.value_conv(x).view(batch, -1, n)                    # (B, C, N)
        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
class ChannelAttentionModule(nn.Module):
    """Channel self-attention over an NCHW feature map.

    The channel-to-channel affinity is computed directly from the input (no
    learned projections). The result is blended residually as
    ``beta * attended + x``; since ``beta`` starts at zero, the module is an
    exact identity mapping at initialisation.
    """

    def __init__(self):
        super(ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))  # learned residual weight, starts at 0
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W): ``beta * attention(x) + x``.
        """
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)                      # (B, C, N)
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))       # (B, C, C)
        attention = self.softmax(affinity)                      # rows sum to 1
        attended = torch.bmm(attention, flat)
        attended = attended.view(batch, channels, height, width)
        return self.beta * attended + x
class DANetHead(nn.Module):
    """DANet segmentation head with parallel position- and channel-attention branches.

    Each branch reduces the input to ``in_channels // 4`` features with a
    3x3 conv-BN-ReLU, applies its attention module, refines with another
    3x3 conv-BN-ReLU, and the two branch outputs are summed before the final
    dropout + 1x1 classifier conv.
    """

    def __init__(self, in_channels, out_channels):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4

        def _conv_bn_relu(c_in, c_out):
            # Shared 3x3 conv + BN + ReLU building block used four times below.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )

        # NOTE: submodule creation order matches the original so RNG-driven
        # parameter initialisation is identical.
        self.conv5a = _conv_bn_relu(in_channels, inter_channels)
        self.conv5c = _conv_bn_relu(in_channels, inter_channels)
        self.sa = PositionAttentionModule(inter_channels)
        self.sc = ChannelAttentionModule()
        self.conv51 = _conv_bn_relu(inter_channels, inter_channels)
        self.conv52 = _conv_bn_relu(inter_channels, inter_channels)
        self.conv8 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels, out_channels, 1),
        )

    def forward(self, x):
        """x: (B, in_channels, H, W) -> (B, out_channels, H, W) logits."""
        position_branch = self.conv51(self.sa(self.conv5a(x)))
        channel_branch = self.conv52(self.sc(self.conv5c(x)))
        return self.conv8(position_branch + channel_branch)
class DANet(nn.Module):
    """Dual Attention Network for semantic segmentation on a dilated ResNet-50.

    The backbone is modified for a smaller output stride: the stride-2 convs
    (and matching downsample convs) of layer3 and layer4 are flattened to
    stride 1, and layer4's 3x3 convs are dilated to preserve receptive field.
    The DANet head runs on the layer4 features and the logits are bilinearly
    upsampled back to the input resolution.

    Args:
        num_classes: number of output channels of the segmentation head.
        backbone: backbone identifier; only ``'resnet50'`` is implemented and
            the value is otherwise ignored (kept for interface compatibility).
        pretrained_base: if True, load ImageNet weights for the ResNet-50
            backbone (requires download on first use).
        aux: stored for interface compatibility; no auxiliary head is built.
    """

    def __init__(self, num_classes=2, backbone='resnet50', pretrained_base=False, aux=False):
        super(DANet, self).__init__()
        self.aux = aux
        # BUG FIX: `pretrained_base` was previously ignored — `weights` was
        # hard-coded to None, so the backbone was always randomly initialised.
        weights = models.ResNet50_Weights.DEFAULT if pretrained_base else None
        resnet = models.resnet50(weights=weights)
        backbone_out_channels = 2048
        # Remove the stride-2 downsampling in layer3 and layer4.
        resnet.layer3[0].conv2.stride = (1, 1)
        resnet.layer3[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        # Dilate layer4's 3x3 convs to compensate for the removed stride.
        for block in resnet.layer4:
            block.conv2.dilation = (2, 2)
            block.conv2.padding = (2, 2)
        # Re-expose the backbone stages as our own submodules.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.head = DANetHead(backbone_out_channels, num_classes)

    def forward(self, x):
        """x: (B, 3, H, W) -> (B, num_classes, H, W) logits at input resolution."""
        imsize = x.size()[2:]
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer2(self.layer1(x))
        c4 = self.layer4(self.layer3(x))
        main_out = self.head(c4)
        # align_corners=True matches the original upsampling behaviour.
        return F.interpolate(main_out, size=imsize, mode='bilinear', align_corners=True)