|
|
""" |
|
|
Model heads |
|
|
""" |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def removekey(d, keys): |
|
|
r = dict(d) |
|
|
for k in keys: |
|
|
del r[k] |
|
|
return r |
|
|
|
|
|
|
|
|
class TaskHead(nn.Module): |
|
|
""" |
|
|
Task-specific head that takes a list of sample features as inputs |
|
|
""" |
|
|
|
|
|
def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level = -1, exclude_keys=[], *kwargs): |
|
|
super(TaskHead, self).__init__() |
|
|
self.out_feat_level = out_feat_level |
|
|
|
|
|
layers = [] |
|
|
for i, in_feature_num in enumerate(f_maps_list[:-1]): |
|
|
layer = ConvBlock(in_feature_num, f_maps_list[i+1], stride = 1, is_3d = is_3d) |
|
|
layers.append(layer) |
|
|
self.layers = nn.ModuleList(layers) |
|
|
|
|
|
conv = nn.Conv3d if is_3d else nn.Conv2d |
|
|
fc = nn.Linear |
|
|
|
|
|
self.out_channels = removekey(out_channels, exclude_keys) |
|
|
self.out_names = self.out_channels.keys() |
|
|
for out_name, out_channels_num in self.out_channels.items(): |
|
|
if out_channels_num > 0: |
|
|
self.add_module("final_conv_%s" % out_name, conv(f_maps_list[-1], out_channels_num, 1)) |
|
|
else: |
|
|
pool_layers = [nn.MaxPool3d(4, 4), |
|
|
ConvBlock(f_maps_list[-1], 16, stride = 1, is_3d = is_3d), |
|
|
nn.MaxPool3d(4, 4), |
|
|
ConvBlock(16, 4, stride = 1, is_3d = is_3d) |
|
|
] |
|
|
self.pool_layers = nn.ModuleList(pool_layers) |
|
|
self.add_module("final_linear1_%s" % out_name, fc(4 * args.size[0] // 16 * args.size[1] // 16 * args.size[2] // 16, 160, 1)) |
|
|
self.add_module("final_linear2_%s" % out_name, fc(160, 10, 1)) |
|
|
self.add_module("final_linear3_%s" % out_name, fc(10, - out_channels_num, 1)) |
|
|
|
|
|
def forward(self, x, *kwargs): |
|
|
x = x[self.out_feat_level] |
|
|
for layer in self.layers: |
|
|
x = layer(x) |
|
|
out = {} |
|
|
for (name, n_channels) in self.out_channels.items(): |
|
|
if n_channels > 0: |
|
|
out[name] = getattr(self, f"final_conv_{name}")(x) |
|
|
else: |
|
|
for layer in self.pool_layers: |
|
|
x = layer(x) |
|
|
x = torch.flatten(x, 1) |
|
|
x = F.relu(getattr(self, f"final_linear1_{name}")(x)) |
|
|
x = F.relu(getattr(self, f"final_linear2_{name}")(x)) |
|
|
out[name] = torch.squeeze(getattr(self, f"final_linear3_{name}")(x), dim = 1) |
|
|
return out |
|
|
|
|
|
|
|
|
class DepHead(nn.Module): |
|
|
""" |
|
|
Task-specific head that takes a list of sample features as inputs |
|
|
For contrast-dependent tasks |
|
|
""" |
|
|
|
|
|
def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level = -1, *kwargs): |
|
|
super(DepHead, self).__init__() |
|
|
self.out_feat_level = out_feat_level |
|
|
|
|
|
f_maps_list[0] += 1 |
|
|
|
|
|
layers = [] |
|
|
for i, in_feature_num in enumerate(f_maps_list[:-1]): |
|
|
layer = ConvBlock(in_feature_num, f_maps_list[i+1], stride = 1, is_3d = is_3d) |
|
|
layers.append(layer) |
|
|
self.layers = nn.ModuleList(layers) |
|
|
|
|
|
conv = nn.Conv3d if is_3d else nn.Conv2d |
|
|
self.out_names = out_channels.keys() |
|
|
for out_name, out_channels_num in out_channels.items(): |
|
|
self.add_module("final_conv_%s" % out_name, conv(f_maps_list[-1], out_channels_num, 1)) |
|
|
|
|
|
def forward(self, x, image): |
|
|
x = x[self.out_feat_level] |
|
|
x = torch.cat([x, image], dim = 1) |
|
|
for layer in self.layers: |
|
|
x = layer(x) |
|
|
out = {} |
|
|
for name in self.out_names: |
|
|
out[name] = getattr(self, f"final_conv_{name}")(x) |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class MultiInputDepHead(DepHead): |
|
|
""" |
|
|
Task-specific head that takes a list of sample features as inputs |
|
|
For contrast-dependent tasks |
|
|
""" |
|
|
|
|
|
def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level = -1, *kwargs): |
|
|
super(MultiInputDepHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level) |
|
|
|
|
|
def forward(self, feat_list, image_list): |
|
|
outs = [] |
|
|
for i, x in enumerate(feat_list): |
|
|
x = x[self.out_feat_level] |
|
|
x = torch.cat([x, image_list[i]], dim = 1) |
|
|
for layer in self.layers: |
|
|
x = layer(x) |
|
|
out = {} |
|
|
for name in self.out_names: |
|
|
out[name] = getattr(self, f"final_conv_{name}")(x) |
|
|
outs.append(out) |
|
|
return outs |
|
|
|
|
|
|
|
|
|
|
|
class MultiInputTaskHead(TaskHead): |
|
|
""" |
|
|
Task-specific head that takes a list of sample features as inputs |
|
|
For contrast-independent tasks |
|
|
""" |
|
|
|
|
|
def __init__(self, args, f_maps_list, out_channels, is_3d, out_feat_level = -1, *kwargs): |
|
|
super(MultiInputTaskHead, self).__init__(args, f_maps_list, out_channels, is_3d, out_feat_level) |
|
|
|
|
|
def forward(self, feat_list, *kwargs): |
|
|
outs = [] |
|
|
for x in feat_list: |
|
|
x = x[self.out_feat_level] |
|
|
for layer in self.layers: |
|
|
x = layer(x) |
|
|
out = {} |
|
|
for name in self.out_names: |
|
|
out[name] = getattr(self, f"final_conv_{name}")(x) |
|
|
outs.append(out) |
|
|
return outs |
|
|
|
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module): |
|
|
""" |
|
|
Specific same-size-output 3x3 convolutional block followed by leakyrelu for unet. |
|
|
""" |
|
|
|
|
|
def __init__(self, in_channels, out_channels, stride=1, is_3d=True): |
|
|
super().__init__() |
|
|
|
|
|
conv = nn.Conv3d if is_3d else nn.Conv2d |
|
|
self.main = conv(in_channels, out_channels, 3, stride, 1) |
|
|
self.activation = nn.LeakyReLU(0.2) |
|
|
|
|
|
def forward(self, x): |
|
|
out = self.main(x) |
|
|
out = self.activation(out) |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level, stage=0, exclude_keys=[]): |
|
|
if 'sep' in train_args.backbone: |
|
|
return get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level) |
|
|
elif '+' in train_args.backbone: |
|
|
if stage == 0: |
|
|
return TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level) |
|
|
else: |
|
|
return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys = ['pathology']) |
|
|
return TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, exclude_keys) |
|
|
|
|
|
|
|
|
def get_sep_head(train_args, f_maps_list, out_channels, is_3d, out_feat_level): |
|
|
head_normal = TaskHead(train_args, f_maps_list, out_channels, is_3d, out_feat_level, ['pathology']) |
|
|
head_pathol = TaskHead(train_args, f_maps_list, {'pathology': 1}, is_3d, out_feat_level) |
|
|
return {'normal': head_normal, 'pathology': head_pathol} |