# tempsal/src/model.py
# (uploaded by baharay — commit a3fc39c, "Upload 31 files")
import torchvision.models as models
import torch
import torch.nn as nn
from collections import OrderedDict
import sys
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from scipy import ndimage
sys.path.append('./PNAS/')
from PNASnet import *
from genotypes import PNASNet
import torch.nn.functional as nnf
import numpy as np
class PNASModel(nn.Module):
    """Saliency model: PNASNet-5 encoder + U-Net-style decoder.

    Multi-scale features are tapped from the PNASNet stem and from cells
    3, 7 and 11, then progressively upsampled and fused with skip
    connections. The output is one sigmoid saliency map per image.
    """

    def __init__(self, num_channels=3, train_enc=False, load_weight=1,
                 model_path='./PNAS/PNASNet-5_Large.pth'):
        """
        Args:
            num_channels: unused; kept for backward compatibility (the
                PNASNet stem is fixed to 3-channel RGB input).
            train_enc: if True, encoder parameters remain trainable;
                otherwise the encoder is frozen.
            load_weight: truthy -> load pretrained encoder weights from
                ``model_path``.
            model_path: checkpoint path for the pretrained PNASNet encoder.
                Fix: the original read ``self.path`` without ever assigning
                it, so construction with ``load_weight`` truthy raised
                AttributeError. The default path is an assumption — confirm
                against the actual checkpoint location.
        """
        super(PNASModel, self).__init__()
        self.path = model_path  # fix: was referenced at load time but never set

        # Encoder: PNASNet-5-Large (216 init channels, 1001 classes, 12 cells).
        self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
        if load_weight:
            self.pnas.load_state_dict(torch.load(self.path))

        # Freeze (or unfreeze) the whole encoder.
        for param in self.pnas.parameters():
            param.requires_grad = train_enc

        # Right/bottom zero-pad by 1 so the stem feature map aligns with the
        # decoder activation it is concatenated with in forward().
        self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.drop_path_prob = 0
        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        # Decoder: conv + ReLU + 2x bilinear upsample per stage. Input channel
        # counts follow PNASNet feature widths (4320/2160/1080/540/96) plus
        # the concatenated skip features.
        self.deconv_layer0 = nn.Sequential(
            nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, images):
        """Return a (B, H, W) saliency map in (0, 1) for a batch of images."""
        # ---- encoder: collect skip features at five scales ----
        s0 = self.pnas.conv0(images)
        s0 = self.pnas.conv0_bn(s0)
        out1 = self.padding(s0)

        s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
        out2 = s1
        # drop_path_prob is 0 throughout, so passing it here is equivalent to
        # the literal 0 the original used for stem2 and the cells.
        s0, s1 = s1, self.pnas.stem2(s0, s1, self.drop_path_prob)

        out3 = out4 = out5 = None
        for i, cell in enumerate(self.pnas.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 3:
                out3 = s1
            elif i == 7:
                out4 = s1
            elif i == 11:
                out5 = s1

        # ---- decoder: upsample and fuse with skips, coarse to fine ----
        out5 = self.deconv_layer0(out5)
        x = torch.cat((out5, out4), 1)
        x = self.deconv_layer1(x)
        x = torch.cat((x, out3), 1)
        x = self.deconv_layer2(x)
        x = torch.cat((x, out2), 1)
        x = self.deconv_layer3(x)
        x = torch.cat((x, out1), 1)
        x = self.deconv_layer4(x)
        x = self.deconv_layer5(x)
        # Drop the singleton channel dimension: (B, 1, H, W) -> (B, H, W).
        return x.squeeze(1)
class PNASVolModellast(nn.Module):
    """Temporal saliency-volume model: PNASNet-5 encoder + decoder that
    predicts ``time_slices`` saliency maps per image.

    ``forward`` also returns the five encoder skip features so a downstream
    consumer (see PNASBoostedModelMultiLevel) can reuse them without
    re-running the encoder.
    """

    def __init__(self, time_slices, num_channels=3, train_enc=False,
                 load_weight=1, model_path='./PNAS/PNASNet-5_Large.pth'):
        """
        Args:
            time_slices: number of output channels (temporal saliency bins).
            num_channels: unused; kept for backward compatibility.
            train_enc: if True, encoder parameters remain trainable.
            load_weight: truthy -> load pretrained encoder weights from
                ``model_path``.
            model_path: checkpoint path for the pretrained encoder.
                Fix: the original read ``self.path`` without ever assigning
                it, raising AttributeError whenever ``load_weight`` was
                truthy. The default path is an assumption — confirm against
                the actual checkpoint location.
        """
        super(PNASVolModellast, self).__init__()
        self.path = model_path  # fix: was referenced at load time but never set

        # Encoder: PNASNet-5-Large (216 init channels, 1001 classes, 12 cells).
        self.pnas = NetworkImageNet(216, 1001, 12, False, PNASNet)
        if load_weight:
            state_dict = torch.load(self.path)
            # Normalize checkpoint keys before loading into the bare encoder:
            # keys saved from a wrapped model keep their 'module' prefix (and
            # gain 'module.pnas.'), others have their 'pnas.' prefix stripped.
            # strict=False tolerates any leftover mismatches.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if 'module' in k:
                    k = 'module.pnas.' + k
                else:
                    k = k.replace('pnas.', '')
                new_state_dict[k] = v
            self.pnas.load_state_dict(new_state_dict, strict=False)

        # Freeze (or unfreeze) the whole encoder.
        for param in self.pnas.parameters():
            param.requires_grad = train_enc

        # Right/bottom zero-pad by 1 so the stem feature map aligns with the
        # decoder activation it is concatenated with in forward().
        self.padding = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.drop_path_prob = 0
        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        # Decoder: conv + ReLU + 2x bilinear upsample per stage; channel
        # counts mirror PNASModel except the head emits `time_slices` maps.
        self.deconv_layer0 = nn.Sequential(
            nn.Conv2d(in_channels=4320, out_channels=512, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=512 + 2160, out_channels=256, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=1080 + 256, out_channels=270, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels=540, out_channels=96, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer5 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=time_slices, kernel_size=3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, images):
        """Return ``(volume, skips)``: a (B, time_slices, H, W) saliency
        volume scaled so its maximum is 1, plus the list of five encoder
        skip features [out1..out5]."""
        # ---- encoder: collect skip features at five scales ----
        s0 = self.pnas.conv0(images)
        s0 = self.pnas.conv0_bn(s0)
        out1 = self.padding(s0)

        s1 = self.pnas.stem1(s0, s0, self.drop_path_prob)
        out2 = s1
        # drop_path_prob is 0 throughout, so passing it here is equivalent to
        # the literal 0 the original used for stem2 and the cells.
        s0, s1 = s1, self.pnas.stem2(s0, s1, self.drop_path_prob)

        out3 = out4 = out5 = None
        for i, cell in enumerate(self.pnas.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            if i == 3:
                out3 = s1
            elif i == 7:
                out4 = s1
            elif i == 11:
                out5 = s1

        # ---- decoder: upsample and fuse with skips, coarse to fine ----
        out5 = self.deconv_layer0(out5)
        x = torch.cat((out5, out4), 1)
        x = self.deconv_layer1(x)
        x = torch.cat((x, out3), 1)
        x = self.deconv_layer2(x)
        x = torch.cat((x, out2), 1)
        x = self.deconv_layer3(x)
        x = torch.cat((x, out1), 1)
        x = self.deconv_layer4(x)
        x = self.deconv_layer5(x)

        # Peak-normalize the volume; the sigmoid head guarantees x > 0,
        # so the division is safe.
        x = x / x.max()
        return x, [out1, out2, out3, out4, out5]
class PNASBoostedModelMultiLevel(nn.Module):
    """Boosting head that fuses a frozen single-map saliency model
    (PNASModel) and a frozen temporal-volume model (PNASVolModellast).

    The two frozen predictions are concatenated into ``x_maps`` and injected
    at every decoder scale (the "+6" in each in_channels: 1 saliency channel
    + 5 volume channels). Weights for all three parts are sliced out of one
    combined checkpoint loaded from ``model_path``.
    """

    def __init__(self, device, model_path, model_vol_path, time_slices, train_model=False, selected_slices=""):
        # NOTE(review): `model_vol_path` and `time_slices` are accepted but
        # never used — the combined checkpoint at `model_path` appears to
        # supply all sub-model weights, and the volume model is hard-coded
        # to 5 slices below. Confirm before relying on either argument.
        super(PNASBoostedModelMultiLevel, self).__init__()
        self.selected_slices = selected_slices
        self.linear_upsampling = nn.UpsamplingBilinear2d(scale_factor=2)
        # Decoder stages mirror PNASVolModellast's channel widths, each with
        # 6 extra input channels for the injected prediction maps.
        self.deconv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels = 512+2160+6, out_channels = 256, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer2 = nn.Sequential(
            nn.Conv2d(in_channels = 1080+256+6, out_channels = 270, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer3 = nn.Sequential(
            nn.Conv2d(in_channels = 540+6, out_channels = 96, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        self.deconv_layer4 = nn.Sequential(
            nn.Conv2d(in_channels = 192+6, out_channels = 128, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            self.linear_upsampling
        )
        # Final mixing head: fuse decoder output with the prediction maps
        # one last time at full resolution and emit a single sigmoid map.
        self.deconv_mix = nn.Sequential(
            nn.Conv2d(in_channels = 128+6 , out_channels = 16, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1, bias = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels = 32, out_channels = 1, kernel_size = 3, padding = 1, bias = True),
            nn.Sigmoid()
        )
        model_vol = PNASVolModellast(time_slices=5, load_weight=0) #change this to time slices
        model_vol = nn.DataParallel(model_vol).cuda()
        # Split the combined checkpoint into the three sub-models' weights,
        # stripping the wrapper prefixes recorded at save time.
        state_dict = torch.load(model_path)
        vol_state_dict = OrderedDict()
        sal_state_dict = OrderedDict()
        smm_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if 'pnas_vol' in k:
                k = k.replace('pnas_vol.module.', '')
                vol_state_dict[k] = v
            elif 'pnas_sal' in k:
                k = k.replace('pnas_sal.module.', '')
                sal_state_dict[k] = v
            else:
                smm_state_dict[k] = v
        # Load the mixing-head weights now, while this module's state dict
        # contains only the deconv layers (pnas_vol / pnas_sal are registered
        # below, after this strict load).
        self.load_state_dict(smm_state_dict)
        model_vol.load_state_dict(vol_state_dict)
        # NOTE(review): model_vol is already a DataParallel, so this wraps it
        # a second time; the stripped key prefixes above depend on exactly
        # this nesting — do not change one without the other. Also note the
        # hard-coded .cuda() here vs .to(device) for pnas_sal below.
        self.pnas_vol = nn.DataParallel(model_vol).cuda()
        for param in self.pnas_vol.parameters():
            param.requires_grad = False
        model = PNASModel(load_weight=0)
        model = nn.DataParallel(model).cuda()
        model.load_state_dict(sal_state_dict, strict=True)
        self.pnas_sal = nn.DataParallel(model).to(device)
        for param in self.pnas_sal.parameters():
            param.requires_grad = False #train_model

    def forward(self, images):
        """Return ``(saliency, volume)``: the fused (B, H, W) saliency map
        and the frozen volume model's (B, 5, H, W) prediction."""
        # print("IMAGES", images.shape)
        # Frozen predictions: (B, 1, H, W) saliency + (B, 5, H, W) volume.
        pnas_pred = self.pnas_sal(images).unsqueeze(1)
        pnas_vol_pred , outs = self.pnas_vol(images)
        out1 , out2, out3, out4, out5 = outs
        #print(pnas_vol_pred.shape)
        # 6-channel guidance maps injected at every decoder scale below.
        x_maps = torch.cat((pnas_pred, pnas_vol_pred), 1)
        x = torch.cat((out5,out4), 1)
        # NOTE(review): guidance is resized to fixed 16/32/64/128 grids, so
        # this head assumes a fixed encoder input resolution — confirm.
        x_maps16 = nnf.interpolate(x_maps, size=(16, 16), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps16), 1)
        x = self.deconv_layer1(x)
        x = torch.cat((x,out3), 1)
        x_maps32 = nnf.interpolate(x_maps, size=(32, 32), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps32), 1)
        x = self.deconv_layer2(x)
        x = torch.cat((x,out2), 1)
        x_maps64 = nnf.interpolate(x_maps, size=(64, 64), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps64), 1)
        x = self.deconv_layer3(x)
        x = torch.cat((x,out1), 1)
        x_maps128 = nnf.interpolate(x_maps, size=(128, 128), mode='bicubic', align_corners=False)
        x = torch.cat((x,x_maps128), 1)
        x = self.deconv_layer4(x)
        # Full-resolution guidance injection, then mix down to one map.
        x = torch.cat((x,x_maps), 1)
        x = self.deconv_mix(x)
        x = x.squeeze(1)
        return x, pnas_vol_pred