Initial upload of DANet fracture segmentation model
Browse files- model.py +68 -0
- pytorch_model.bin +3 -0
- requirements.txt +4 -0
model.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from torchvision import models
|
| 6 |
+
|
| 7 |
+
class PositionAttentionModule(nn.Module):
    """Spatial self-attention (DANet PAM): every position attends to all others.

    The output is the residual blend ``gamma * attended + x``; ``gamma`` is a
    learned scalar initialized to zero, so the module is an identity mapping
    at initialization.
    """

    def __init__(self, in_dim):
        super(PositionAttentionModule, self).__init__()
        # Kept for checkpoint/state-dict compatibility (original attribute name,
        # typo included); it is not read anywhere in this module.
        self.chanel_in = in_dim
        # Query/key are projected down to C//8 channels to cheapen the
        # (H*W) x (H*W) affinity computation; value keeps full width.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention; input and output are both (B, C, H, W)."""
        batch, channels, height, width = x.size()
        n_pos = height * width
        # (B, N, C//8) @ (B, C//8, N) -> (B, N, N) pairwise position affinities.
        queries = self.query_conv(x).view(batch, -1, n_pos).permute(0, 2, 1)
        keys = self.key_conv(x).view(batch, -1, n_pos)
        attention_map = self.softmax(torch.bmm(queries, keys))
        values = self.value_conv(x).view(batch, -1, n_pos)
        # Aggregate values with the transposed attention map, restore (B, C, H, W).
        attended = torch.bmm(values, attention_map.permute(0, 2, 1))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
|
| 24 |
+
|
| 25 |
+
class ChannelAttentionModule(nn.Module):
    """Channel self-attention (DANet CAM): channels attend to each other.

    Parameter-free apart from the learned residual weight ``beta``, which is
    initialized to zero, making the module an identity mapping at init.
    """

    def __init__(self):
        super(ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply channel attention; input and output are both (B, C, H, W)."""
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)  # (B, C, N)
        # (B, C, N) @ (B, N, C) -> (B, C, C) channel affinity matrix.
        # NOTE(review): the reference DANet subtracts each row from its max
        # before the softmax; this variant applies softmax directly. Kept
        # as-is to stay faithful to the shipped checkpoint.
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))
        attention_map = self.softmax(affinity)
        out = torch.bmm(attention_map, flat).view(batch, channels, height, width)
        return self.beta * out + x
|
| 36 |
+
|
| 37 |
+
class DANetHead(nn.Module):
    """Dual-attention head: parallel position- and channel-attention branches
    whose outputs are summed and projected to ``out_channels`` logits."""

    def __init__(self, in_channels, out_channels):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4  # narrow the features before attention

        def _conv_bn_relu(cin, cout):
            # Shared 3x3 conv block (bias omitted because BatchNorm follows).
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
            )

        self.conv5a = _conv_bn_relu(in_channels, inter_channels)
        self.conv5c = _conv_bn_relu(in_channels, inter_channels)
        self.sa = PositionAttentionModule(inter_channels)
        self.sc = ChannelAttentionModule()
        self.conv51 = _conv_bn_relu(inter_channels, inter_channels)
        self.conv52 = _conv_bn_relu(inter_channels, inter_channels)
        self.conv8 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels, out_channels, 1),
        )

    def forward(self, x):
        """Return per-pixel logits of shape (B, out_channels, H, W)."""
        position_branch = self.conv51(self.sa(self.conv5a(x)))
        channel_branch = self.conv52(self.sc(self.conv5c(x)))
        return self.conv8(position_branch + channel_branch)
|
| 52 |
+
|
| 53 |
+
class DANet(nn.Module):
    """DANet semantic-segmentation model on a dilated ResNet-50 backbone.

    Args:
        num_classes: number of output channels (segmentation classes).
        backbone: backbone identifier; only ``'resnet50'`` is supported.
        pretrained_base: when True, load ImageNet weights into the backbone.
        aux: stored for API compatibility; no auxiliary head is built here.

    Raises:
        ValueError: if ``backbone`` is not ``'resnet50'``.
    """

    def __init__(self, num_classes=2, backbone='resnet50', pretrained_base=False, aux=False):
        super(DANet, self).__init__()
        self.aux = aux
        # FIX: `backbone` and `pretrained_base` were previously ignored —
        # any backbone string silently produced an untrained resnet50.
        # Fail fast on unsupported backbones and honor the pretrained flag.
        if backbone != 'resnet50':
            raise ValueError(
                f"unsupported backbone {backbone!r}; only 'resnet50' is supported"
            )
        weights = models.ResNet50_Weights.DEFAULT if pretrained_base else None
        resnet = models.resnet50(weights=weights)
        backbone_out_channels = 2048
        # Remove the spatial stride in layer3/layer4 so the attention head
        # sees a higher-resolution feature map.
        resnet.layer3[0].conv2.stride = (1, 1)
        resnet.layer3[0].downsample[0].stride = (1, 1)
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        # Dilate layer4's 3x3 convs to preserve receptive field after the
        # stride removal (padding grows with dilation to keep sizes fixed).
        for bottleneck in resnet.layer4:
            bottleneck.conv2.dilation = (2, 2)
            bottleneck.conv2.padding = (2, 2)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.head = DANetHead(backbone_out_channels, num_classes)

    def forward(self, x):
        """Return per-pixel class logits upsampled to the input's spatial size."""
        imsize = x.size()[2:]
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer2(self.layer1(x))
        c4 = self.layer4(self.layer3(x))
        main_out = self.head(c4)
        return F.interpolate(main_out, size=imsize, mode='bilinear', align_corners=True)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5aa56187075bea9a05e6f21130a58e0bf7249642764c13d80685a4ebe44ca399
|
| 3 |
+
size 208980802
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
numpy
|
| 4 |
+
Pillow
|