salmasoma
Set up inference-only HyperClinical Streamlit app with runtime HF asset download
278bf2b
import numpy as np
import torch.nn as nn
import torch
from model.modules import ResidualAttentionNet, conv_block
class AVRA_rnn(nn.Module):
    """Per-slice CNN feature extractor followed by an LSTM and a scalar head.

    Every slice of an input sequence is passed through the
    ``ResidualAttentionNet`` feature extractor. The flattened per-slice
    features are fed, in order, to a 2-layer unidirectional LSTM. The LSTM
    output at the last timestep is mapped to a single value by a linear layer.

    Args:
        input_dims: Spatial size of one slice; only ``input_dims[0]`` and
            ``input_dims[1]`` are read. The feature extractor is probed with a
            single-channel input, so slices presumably must have ``C == 1`` —
            TODO confirm against callers.
    """

    def __init__(self, input_dims):
        super().__init__()
        self.features = ResidualAttentionNet()
        # Probe shape is (channels, H, W); get_flat_fts prepends a batch dim of 1.
        input_size = [1, input_dims[0], input_dims[1]]
        # Flattened size of one slice's feature map == LSTM input width.
        self.l = self.get_flat_fts(input_size, self.features)
        self.hs = 256
        self.rnn = nn.LSTM(
            input_size=self.l,
            hidden_size=self.hs,
            num_layers=2,
            batch_first=True,
            bidirectional=False,
        )
        # A bidirectional LSTM concatenates both directions' outputs.
        directions = 2 if self.rnn.bidirectional else 1
        self.linear = nn.Linear(self.hs * directions, 1)

    def get_flat_fts(self, in_size, fts):
        """Return the per-sample flattened output size of module ``fts``.

        Runs ``fts`` once on an all-ones tensor of shape ``(1, *in_size)`` and
        returns the product of the output shape excluding the batch dim.

        The probe runs under ``torch.no_grad()`` with ``fts`` temporarily in
        eval mode. This means no autograd graph is built, and layers that
        update state in training mode (e.g. BatchNorm running statistics) are
        not altered by the dummy input. Afterwards ``fts.train(previous_mode)``
        is called, which restores the mode recursively on all submodules.
        """
        was_training = fts.training
        fts.eval()
        try:
            with torch.no_grad():
                probe_out = fts(torch.ones(1, *in_size))
        finally:
            fts.train(was_training)
        return int(np.prod(probe_out.shape[1:]))

    def forward(self, x, return_r_out=False):
        """Score a batch of slice sequences.

        Args:
            x: Tensor of shape ``(batch, timesteps, C, H, W)``. It does not
                need to be contiguous.
            return_r_out: If true, also return the LSTM *input*: the per-slice
                CNN features, flattened to ``(batch, timesteps * feat)``.
                Despite the flag's name, this is not the LSTM output.

        Returns:
            A ``(batch, 1)`` tensor computed from the last timestep. When
            ``return_r_out`` is true, a tuple ``(that tensor, features)``.
        """
        batch_size, timesteps, C, H, W = x.size()
        # Fold time into the batch so all slices go through the CNN in one call.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        c_in = x.reshape(batch_size * timesteps, C, H, W)
        c_out = self.features(c_in)
        r_in = c_out.reshape(batch_size, timesteps, -1)
        r_out, _ = self.rnn(r_in)
        r_out_last = self.linear(r_out[:, -1, :])
        if return_r_out:
            return r_out_last, r_in.reshape(batch_size, -1)
        return r_out_last
class VGG_bl(nn.Module):
    """VGG-style 2-D CNN baseline with a three-layer MLP head producing one value.

    Args:
        input_dims: ``(x, y, z)``. ``x`` and ``y`` are the spatial sizes used
            to size the MLP head; ``z`` is the input channel count of the
            first block. Presumably ``forward`` receives tensors shaped
            ``(batch, z, x, y)`` — TODO confirm against callers.

    Raises:
        ValueError: If ``x`` or ``y`` is smaller than ``2 ** len(num_filters)``.
            The flattened feature width computed for the head would then be
            zero, producing an unusable model.
    """

    def __init__(self, input_dims):
        super().__init__()
        x, y, z = input_dims
        self.num_filters = [64, 128, 256, 512, 512]
        # Layer *classes* (not instances) handed to conv_block to build each block.
        self.convxd = nn.Conv2d
        self.pooling = nn.MaxPool2d
        self.norm = nn.BatchNorm2d
        self.relu = nn.LeakyReLU

        # The head width assumes each block halves both spatial dims once.
        # That is an assumption about conv_block (defined elsewhere) — confirm.
        downsample = 2 ** len(self.num_filters)
        if x < downsample or y < downsample:
            raise ValueError(
                f"spatial input size {x}x{y} is too small for "
                f"{len(self.num_filters)} downsampling blocks "
                f"(need at least {downsample} in each dimension)"
            )
        a = int((x // downsample) * (y // downsample) * self.num_filters[-1])

        # Third positional conv_block argument for each block, unchanged from
        # the original hand-written layer list (its meaning lives in conv_block).
        block_flags = [False, False, True, True, True]
        in_channels = [z] + self.num_filters[:-1]
        self.features = nn.Sequential(*[
            conv_block(c_in, c_out, flag, self.convxd, self.norm, self.pooling,
                       relu=self.relu)
            for c_in, c_out, flag in zip(in_channels, self.num_filters, block_flags)
        ])

        hidden = 4096
        self.fc1 = nn.Sequential(
            nn.Linear(a, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Run the CNN, flatten each sample's features, and return ``(batch, 1)``."""
        x = self.features(x)
        # flatten (not view) also works when the feature map is non-contiguous.
        x = torch.flatten(x, 1)
        return self.fc1(x)