salmasoma
Set up inference-only HyperClinical Streamlit app with runtime HF asset download
278bf2b
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Residual Attention Network module - feature extractor in AVRA."""
import torch.nn as nn
import torch.nn.functional as F
class ResidualAttentionNet(nn.Module):
    """Residual Attention Network trunk mapping an image batch to flat features.

    ``self.features`` is, in order:

    * a 7x7 stride-2 stem conv (BN + ReLU) and a stride-2 max-pool;
    * a residual unit widening 8 -> 16 channels;
    * attention stages 1, 2 and 3, each followed by a stride-2 residual unit
      that doubles the width (16 -> 32 -> 64 -> 128);
    * two residual units at 128 channels;
    * a BN -> ReLU -> 3x3 stride-1 average-pool head.

    Args:
        z: number of input channels of the stem convolution.
    """

    def __init__(self, z=1):
        super().__init__()
        widths = [8, 16, 32, 64, 128]
        # Construction order and Sequential positions are kept fixed: they
        # determine parameter initialisation under a seed and the
        # ``features.<idx>...`` state-dict keys used when loading weights.
        stem = nn.Sequential(
            nn.Conv2d(z, widths[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        )
        layers = [
            stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ResidualModule(widths[0], widths[1]),
        ]
        for stage in (1, 2, 3):
            width = widths[stage]
            layers.append(AttentionModule(width, width, stage=stage))
            layers.append(ResidualModule(width, widths[stage + 1], stride=2))
        final = widths[-1]
        layers += [
            ResidualModule(final, final),
            ResidualModule(final, final),
            nn.Sequential(
                nn.BatchNorm2d(final),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=3, stride=1),
            ),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Run the feature stack and flatten each sample to one vector.

        ``x`` is presumably ``(batch, z, H, W)`` (it feeds ``nn.Conv2d``);
        the result has shape ``(batch, -1)``.
        """
        feats = self.features(x)
        return feats.view(feats.shape[0], -1)
class ResidualModule(nn.Module):
    """Pre-activation bottleneck residual unit.

    Residual branch::

        BN -> LeakyReLU -> 1x1 conv (planes // 4)
           -> BN -> LeakyReLU -> 3x3 conv (stride)
           -> BN -> LeakyReLU -> 1x1 conv (planes)

    Shortcut: the raw input ``x`` when channel count and stride are
    unchanged; otherwise a strided 1x1 projection (``conv4``) of the
    pre-activated input (output of ``bn1`` + ``relu1``).

    Args:
        inplanes: input channel count.
        planes: output channel count; the bottleneck width is ``planes // 4``.
        stride: stride of the 3x3 conv and of the projection shortcut.
    """

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        narrow = planes // 4
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        # Attribute names and registration order define the state-dict
        # layout; keep them as they are.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.LeakyReLU()
        self.conv1 = nn.Conv2d(inplanes, narrow, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(narrow)
        self.relu2 = nn.LeakyReLU()
        self.conv2 = nn.Conv2d(narrow, narrow, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(narrow)
        self.relu3 = nn.LeakyReLU()
        self.conv3 = nn.Conv2d(narrow, planes, kernel_size=1, stride=1, bias=False)
        # conv4 is registered unconditionally (so the parameter set does not
        # depend on the configuration) but forward only uses it when the
        # shape changes.
        self.conv4 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.downsample = stride != 1 or inplanes != planes

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)``."""
        activated = self.relu1(self.bn1(x))
        branch = self.relu2(self.bn2(self.conv1(activated)))
        branch = self.relu3(self.bn3(self.conv2(branch)))
        branch = self.conv3(branch)
        shortcut = self.conv4(activated) if self.downsample else x
        branch += shortcut
        return branch
class AttentionModule(nn.Module):
    """Attention module: a trunk branch modulated by a soft-mask branch.

    The mask branch is an hourglass whose depth depends on ``stage``:

    * stage 1: three max-pools, two skip connections (``skip1``, ``skip2``);
    * stage 2: two max-pools, one skip connection (``skip1``);
    * stage 3: one max-pool, no skip connection.

    It is upsampled back by the same factor and squashed to (0, 1) by
    ``block_sigmoid``.

    Args:
        in_planes: input channels.
        out_planes: output channels of each residual unit. Residual units
            are chained (e.g. the two in ``trunk_branch``), so in practice
            this must equal ``in_planes``.
        stage: 1, 2 or 3 (the values used in this file).
    """

    def __init__(self, in_planes, out_planes, stage=1):
        super().__init__()
        self.stage = stage

        def unit():
            return ResidualModule(in_planes, out_planes)

        # Registration order is kept fixed: it affects parameter
        # initialisation under a seed and the state-dict layout.
        self.res1 = unit()
        self.trunk_branch = nn.Sequential(unit(), unit())
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        if stage < 3:
            self.block1 = unit()
            self.skip1 = unit()
            self.block5 = unit()
            if stage == 1:
                self.block2 = unit()
                self.skip2 = unit()
                self.block4 = unit()
        self.block3 = nn.Sequential(unit(), unit())
        self.block_sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )
        self.block6 = unit()

    def upsample(self, x):
        """Bilinearly upsample ``x`` by a factor of 2 (``align_corners=True``).

        NOTE(review): the pool (k=3, s=2, p=1) maps size n to ceil(n / 2), so
        the upsampled tensor only matches the skip/trunk tensor when H and W
        are even at every pooling level -- confirm for the input sizes used.
        """
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return ``block6`` applied to the mask-modulated trunk output."""
        x = self.res1(x)
        trunk = self.trunk_branch(x)

        # Downsampling half of the mask branch.
        if self.stage < 3:
            x = self.block1(self.maxpool(x))
            skip_outer = self.skip1(x)
            if self.stage == 1:
                x = self.block2(self.maxpool(x))
                skip_inner = self.skip2(x)
        x = self.block3(self.maxpool(x))

        # Upsampling half: merge the matching skip connections on the way up.
        if self.stage == 1:
            x = self.block4(self.upsample(x) + skip_inner)
        if self.stage < 3:
            x = self.block5(self.upsample(x) + skip_outer)
        mask = self.block_sigmoid(self.upsample(x))

        # NOTE(review): this yields (2 + mask) * trunk, not the (1 + mask) *
        # trunk of attention residual learning. Kept as-is because changing it
        # would alter the outputs of any weights trained with this module.
        gated = (1 + mask) * trunk
        return self.block6(gated + trunk)
def conv_block(in_planes, out_planes, bigblock, convxd, norm, pooling, fs=3, stride=1, relu=nn.ReLU):
    """Build a stack of conv -> activation -> norm layers followed by pooling.

    The stack has two convolutions, or three when ``bigblock`` is true.
    Each convolution is followed by the activation and then the
    normalisation layer (activation before norm). Only the last
    convolution uses ``stride``. Every convolution uses padding
    ``fs // 2``. The block ends with ``pooling(2, 2)``.

    Args:
        in_planes: input channels of the first convolution.
        out_planes: output channels of every convolution / norm layer.
        bigblock: if true, use three convolutions instead of two.
        convxd: conv class, called as ``convxd(in, out, fs, stride, padding)``
            (e.g. ``nn.Conv2d`` or ``nn.Conv3d``).
        norm: normalisation class, called as ``norm(out_planes)``.
        pooling: pooling class, called as ``pooling(2, 2)``.
        fs: kernel size.
        stride: stride of the last convolution.
        relu: activation class, instantiated as ``relu(inplace=True)``.

    Returns:
        ``nn.Sequential`` laid out as conv, act, norm, ..., pool. The layer
        indices are the same as before for both block sizes.
    """
    padding = fs // 2
    n_convs = 3 if bigblock else 2
    layers = []
    for i in range(n_convs):
        channels_in = in_planes if i == 0 else out_planes
        conv_stride = stride if i == n_convs - 1 else 1
        layers += [
            convxd(channels_in, out_planes, fs, conv_stride, padding),
            # Pass ``inplace`` by keyword: positionally, e.g.
            # ``nn.LeakyReLU(True)`` sets negative_slope=1 (an identity)
            # and ``nn.ELU(True)`` sets alpha, instead of enabling in-place.
            relu(inplace=True),
            norm(out_planes),
        ]
    layers.append(pooling(2, 2))
    return nn.Sequential(*layers)