# Hugging Face upload metadata (web-page residue, kept as a comment so the
# module stays importable): uploader peirong26, "Upload 187 files", commit 2571f24 (verified).
"""
Model heads
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def removekey(d, keys):
    """Return a shallow copy of ``d`` with every key in ``keys`` deleted.

    Raises:
        KeyError: if any requested key is absent (same as ``del``).
    """
    pruned = dict(d)
    for unwanted in keys:
        del pruned[unwanted]
    return pruned
class TaskHead(nn.Module):
    """
    Task-specific head that takes a list of sample features as inputs.

    For each entry in ``out_channels``:
      * value > 0  -> a 1x1 conv producing that many output channels;
      * value <= 0 -> a single-value regression branch (e.g. age prediction):
        pool/conv downsampling followed by three fully-connected layers whose
        final width is ``-value``.

    Args:
        args: config namespace; only ``args.size`` is read, and only when a
            scalar-output branch is built.
        f_maps_list: channel counts; consecutive pairs become 3x3 ConvBlocks,
            the last entry feeds the final 1x1 conv(s).
        out_channels: dict mapping output name -> channel count (see above).
        is_3d: build 3-D (Conv3d/MaxPool3d) or 2-D (Conv2d/MaxPool2d) modules.
        out_feat_level: which entry of the input feature list to use.
        exclude_keys: output names to drop from ``out_channels`` (optional).
    """
    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1,
                 exclude_keys=None, *unused_args):
        super(TaskHead, self).__init__()
        self.out_feat_level = out_feat_level
        # Additional same-size-output 3x3 conv layers before final_conv,
        # built only when len(f_maps_list) > 1.
        layers = [ConvBlock(f_maps_list[i], f_maps_list[i + 1], stride=1, is_3d=is_3d)
                  for i in range(len(f_maps_list) - 1)]
        self.layers = nn.ModuleList(layers)
        conv = nn.Conv3d if is_3d else nn.Conv2d
        # Fixed: choose the pool matching dimensionality (MaxPool3d was
        # hard-coded and would fail on 4-D / 2-D inputs).
        pool = nn.MaxPool3d if is_3d else nn.MaxPool2d
        fc = nn.Linear
        # Fixed: mutable default argument replaced by None sentinel; missing
        # exclude keys are skipped instead of raising KeyError.
        exclude = set(exclude_keys) if exclude_keys else set()
        self.out_channels = {k: v for k, v in out_channels.items() if k not in exclude}
        self.out_names = self.out_channels.keys()
        for out_name, out_channels_num in self.out_channels.items():
            if out_channels_num > 0:
                self.add_module("final_conv_%s" % out_name,
                                conv(f_maps_list[-1], out_channels_num, 1))
            else:  # single value output (age prediction)
                # NOTE(review): pool_layers is shared state — a second scalar
                # output would overwrite it (original code had the same limit).
                pool_layers = [pool(4, 4),  # (160 -> 40)
                               ConvBlock(f_maps_list[-1], 16, stride=1, is_3d=is_3d),
                               pool(4, 4),  # (40 -> 10)
                               ConvBlock(16, 4, stride=1, is_3d=is_3d)]
                self.pool_layers = nn.ModuleList(pool_layers)
                # NOTE(review): in_features assumes a 3-axis volume downsampled
                # 16x per axis — confirm args.size has three entries here.
                in_features = 4 * args.size[0] // 16 * args.size[1] // 16 * args.size[2] // 16
                self.add_module("final_linear1_%s" % out_name, fc(in_features, 160))
                self.add_module("final_linear2_%s" % out_name, fc(160, 10))
                self.add_module("final_linear3_%s" % out_name, fc(10, -out_channels_num))

    def forward(self, x, *unused_args):
        """x: list of feature maps; returns dict mapping name -> prediction."""
        feats = x[self.out_feat_level]
        for layer in self.layers:
            feats = layer(feats)
        out = {}
        for name, n_channels in self.out_channels.items():
            if n_channels > 0:
                out[name] = getattr(self, f"final_conv_{name}")(feats)
            else:
                # Fixed: pool a local copy so the shared feature map is not
                # clobbered for subsequent outputs (the original reused `x`).
                h = feats
                for layer in self.pool_layers:
                    h = layer(h)
                h = torch.flatten(h, 1)  # flatten all dimensions except batch
                h = F.relu(getattr(self, f"final_linear1_{name}")(h))
                h = F.relu(getattr(self, f"final_linear2_{name}")(h))
                out[name] = torch.squeeze(getattr(self, f"final_linear3_{name}")(h), dim=1)
        return out
class DepHead(nn.Module):
    """
    Task-specific head for contrast-DEPENDENT tasks: the input image/contrast
    is concatenated (one extra channel) onto the selected feature map before
    the conv stack.

    Args mirror TaskHead, except every entry of ``out_channels`` must be a
    positive channel count (there is no scalar-output branch here).
    """
    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *unused_args):
        super(DepHead, self).__init__()
        self.out_feat_level = out_feat_level
        # Fixed: copy before editing so the caller's list is not mutated
        # (the original did `f_maps_list[0] += 1` in place).
        f_maps_list = list(f_maps_list)
        f_maps_list[0] += 1  # add one input image/contrast channel
        # Additional same-size-output 3x3 conv layers before final_conv,
        # built only when len(f_maps_list) > 1.
        layers = [ConvBlock(f_maps_list[i], f_maps_list[i + 1], stride=1, is_3d=is_3d)
                  for i in range(len(f_maps_list) - 1)]
        self.layers = nn.ModuleList(layers)
        conv = nn.Conv3d if is_3d else nn.Conv2d
        self.out_names = out_channels.keys()
        for out_name, out_channels_num in out_channels.items():
            self.add_module("final_conv_%s" % out_name,
                            conv(f_maps_list[-1], out_channels_num, 1))

    def forward(self, x, image):
        """x: list of feature maps; image: tensor concatenated on channel dim."""
        feats = x[self.out_feat_level]
        feats = torch.cat([feats, image], dim=1)
        for layer in self.layers:
            feats = layer(feats)
        return {name: getattr(self, f"final_conv_{name}")(feats) for name in self.out_names}
class MultiInputDepHead(DepHead):
    """
    Contrast-dependent head applied independently to each (features, image)
    pair of a list of samples; returns one output dict per sample.
    """

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *kwargs):
        super(MultiInputDepHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level)

    def forward(self, feat_list, image_list):
        outs = []
        for idx, feats in enumerate(feat_list):
            # Select the requested feature level, then append the matching
            # image/contrast as an extra channel.
            h = feats[self.out_feat_level]
            h = torch.cat([h, image_list[idx]], dim=1)
            for block in self.layers:
                h = block(h)
            outs.append({name: getattr(self, f"final_conv_{name}")(h)
                         for name in self.out_names})
        return outs
class MultiInputTaskHead(TaskHead):
    """
    Contrast-independent head applied independently to each sample's feature
    list; returns one output dict per sample.
    """

    def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level=-1, *kwargs):
        super(MultiInputTaskHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level)

    def forward(self, feat_list, *kwargs):
        outs = []
        for feats in feat_list:
            h = feats[self.out_feat_level]
            for block in self.layers:
                h = block(h)
            outs.append({name: getattr(self, f"final_conv_{name}")(h)
                         for name in self.out_names})
        return outs
class ConvBlock(nn.Module):
    """Same-size-output 3x3 convolution followed by LeakyReLU(0.2), for the unet."""

    def __init__(self, in_channels, out_channels, stride=1, is_3d=True):
        super().__init__()
        conv_cls = nn.Conv3d if is_3d else nn.Conv2d
        # kernel 3 with padding 1 keeps the spatial size when stride == 1
        self.main = conv_cls(in_channels, out_channels, 3, stride, 1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.activation(self.main(x))
################################
def get_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level, stage=0, exclude_keys=None):
    """Factory for task heads.

    Returns:
        * a dict of two TaskHeads when the backbone name contains 'sep'
          (separate normal / pathology decoders);
        * for two-stage inpainting ('+' in backbone): a pathology-only head at
          ``stage == 0``, otherwise a head with 'pathology' excluded;
        * otherwise a single TaskHead over ``out_channels``.
    """
    # Fixed: mutable default argument replaced by a None sentinel.
    if exclude_keys is None:
        exclude_keys = []
    if 'sep' in train_args.backbone:  # separate decoder and head for anomaly/pathology segmentation
        return get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level)
    if '+' in train_args.backbone:  # two-stage network for inpainting
        if stage == 0:
            return TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level)
        return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys=['pathology'])
    return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys)
def get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level):
    """Build a pair of heads: one for the normal targets (with 'pathology'
    excluded) and a dedicated single-channel pathology head."""
    normal_head = TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, ['pathology'])
    pathology_head = TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level)
    return {'normal': normal_head, 'pathology': pathology_head}