|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.ops import MLP |
|
|
import torchvision.models as models |
|
|
from typing import Dict, Optional |
|
|
|
|
|
|
|
|
class BaseNet:
    """
    Collection of class-level tables and classmethods used to build networks.

    Nothing here needs instance state: the lookup tables are class attributes
    and every builder is a classmethod.
    """

    # torchvision builders for the supported CNNs, keyed by the network name
    # used throughout this class.
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large,
    }

    # torchvision builders for the supported ViTs.
    # The last two characters of each key are read as the patch size in set_vit().
    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14,
    }

    # Every supported image network (CNNs first, then ViTs).
    net = {**cnn, **vit}

    # Attribute name of the classification head for each network family.
    # Note that the ConvNeXt family is keyed as 'ConvNext' in this table.
    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads',
    }

    # Attribute name of the classification head for each concrete network name.
    classifier = {
        **dict.fromkeys(('ResNet18', 'ResNet'), _classifier['ResNet']),
        'DenseNet': _classifier['DenseNet'],
        **dict.fromkeys(
            (
                'EfficientNetB0',
                'EfficientNetB2',
                'EfficientNetB4',
                'EfficientNetB6',
                'EfficientNetV2s',
                'EfficientNetV2m',
                'EfficientNetV2l',
            ),
            _classifier['EfficientNet'],
        ),
        **dict.fromkeys(
            ('ConvNeXtTiny', 'ConvNeXtSmall', 'ConvNeXtBase', 'ConvNeXtLarge'),
            _classifier['ConvNext'],
        ),
        **dict.fromkeys(
            ('ViTb16', 'ViTb32', 'ViTl16', 'ViTl32', 'ViTH14'),
            _classifier['ViT'],
        ),
    }

    # Settings handed to torchvision.ops.MLP by MLPNet().
    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2,
    }

    # One shared identity layer, used wherever a layer has to be a no-op.
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
        """
        Build an MLP configured from ``mlp_config``.

        Args:
            mlp_num_inputs (int): number of input features of the MLP
            inplace (bool, optional): forwarded as-is to torchvision's MLP. Defaults to None.

        Returns:
            MLP: the constructed MLP
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        return MLP(
            in_channels=mlp_num_inputs,
            hidden_channels=cls.mlp_config['hidden_channels'],
            inplace=inplace,
            dropout=cls.mlp_config['dropout'],
        )

    @classmethod
    def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
        """
        Convert the first convolution of a network to a single input channel.

        The convolution's weight is summed over its input-channel dimension
        and the layer is modified in place.

        Args:
            net_name (str): network name, used to locate the first convolution
            net (nn.Module): network to modify

        Returns:
            nn.Module: the same network, now taking 1-channel input

        Raises:
            ValueError: if net_name matches none of the known families
        """
        # Locate the input convolution for the network family.
        if net_name.startswith('ResNet'):
            first_conv = net.conv1
        elif net_name.startswith('DenseNet'):
            first_conv = net.features.conv0
        elif net_name.startswith(('Efficient', 'ConvNeXt')):
            first_conv = net.features[0][0]
        elif net_name.startswith('ViT'):
            first_conv = net.conv_proj
        else:
            raise ValueError(f"No specified net: {net_name}.")

        # Sum the kernels over the input-channel axis, keeping a channel axis of size 1.
        first_conv.in_channels = 1
        first_conv.weight = nn.Parameter(first_conv.weight.sum(dim=1).unsqueeze(1))
        return net

    @classmethod
    def set_net(
        cls,
        net_name: str = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Instantiate a CNN or ViT and adapt it to in_channel and vit_image_size.

        Args:
            net_name (str): network name, must be a key of ``net``
            in_channel (int, optional): number of image channels; only the value 1
                triggers a modification. Defaults to None.
            vit_image_size (int, optional): image size handled by ViT, used only when
                net_name is a ViT. Defaults to None.
                NOTE(review): presumably this has to be a multiple of the patch size,
                as torchvision's VisionTransformer requires - confirm.
            pretrained (bool, optional): for CNNs, load the 'DEFAULT' weights when truthy.
                Not consulted for ViT, which always goes through set_vit(). Defaults to None.

        Returns:
            nn.Module: the prepared network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."

        if net_name in cls.cnn:
            builder = cls.cnn[net_name]
            net = builder(weights='DEFAULT') if pretrained else builder()
        else:
            # ViT: set_vit() loads the 'DEFAULT' weights regardless of `pretrained`.
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)

        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
        """
        Build a ViT for vit_image_size, initialised from the 'DEFAULT' weights.

        The state dict of the default-weight model is passed through torchvision's
        ``interpolate_embeddings`` for the requested image size and then loaded
        into a ViT created with ``image_size=vit_image_size``.

        Args:
            net_name (str): ViT name; its last two characters are parsed as the patch size
            vit_image_size (int): image size the ViT should handle

        Returns:
            nn.Module: ViT built for vit_image_size with the adapted weights loaded
        """
        vit_builder = cls.vit[net_name]

        # State dict of the ViT with its default weights.
        source_state = vit_builder(weights='DEFAULT').state_dict()

        # Adapt that state dict to the requested image size.
        resized_state = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=int(net_name[-2:]),
            model_state=source_state,
        )

        resized_vit = vit_builder(image_size=vit_image_size)
        resized_vit.load_state_dict(resized_state)
        return resized_vit

    @classmethod
    def construct_extractor(
        cls,
        net_name: str = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Build the feature extractor for net_name.

        For 'MLP' this is the MLP itself. For any other name the network is built
        with set_net() and its classification head is replaced by ``DUMMY``.

        Args:
            net_name (str): 'MLP' or a key of ``net``
            mlp_num_inputs (int, optional): number of inputs of the MLP. Defaults to None.
            in_channel (int, optional): number of image channels. Defaults to None.
            vit_image_size (int, optional): image size handled by ViT. Defaults to None.
            pretrained (bool, optional): forwarded to set_net(). Defaults to None.

        Returns:
            nn.Module: extractor of the network
        """
        if net_name == 'MLP':
            return cls.MLPNet(mlp_num_inputs=mlp_num_inputs)

        extractor = cls.set_net(
            net_name=net_name,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained,
        )
        # Neutralise the original head so that the network outputs features.
        setattr(extractor, cls.classifier[net_name], cls.DUMMY)
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Return the classification head of a freshly built network.

        The builder is called with no arguments, so a whole new network is
        instantiated on every call only to read its head.

        Args:
            net_name (str): network name, must be a key of ``net``

        Returns:
            nn.Module: classifier of the network
        """
        return getattr(cls.net[net_name](), cls.classifier[net_name])

    @classmethod
    def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
        """
        Build one classification head per label.

        The layout of each head follows the original head of the network family.

        Args:
            net_name (str): 'MLP' or network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: heads keyed by label name

        Raises:
            ValueError: if net_name matches none of the known families
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            heads = {
                label_name: nn.Linear(in_features, num_outputs)
                for label_name, num_outputs in num_outputs_for_label.items()
            }

        elif net_name.startswith(('ResNet', 'DenseNet')):
            # The original head is a single linear layer.
            in_features = cls.get_classifier(net_name).in_features
            heads = {
                label_name: nn.Linear(in_features, num_outputs)
                for label_name, num_outputs in num_outputs_for_label.items()
            }

        elif net_name.startswith('EfficientNet'):
            # The original head is indexed as [0] dropout, [1] linear.
            reference_head = cls.get_classifier(net_name)
            dropout = reference_head[0].p
            in_features = reference_head[1].in_features
            heads = {
                label_name: nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs),
                )
                for label_name, num_outputs in num_outputs_for_label.items()
            }

        elif net_name.startswith('ConvNeXt'):
            # The original head is indexed as [0] layer norm, [1] flatten, [2] linear.
            reference_head = cls.get_classifier(net_name)
            layer_norm = reference_head[0]
            flatten = reference_head[1]
            in_features = reference_head[2].in_features
            # NOTE(review): the very same layer_norm and flatten objects go into every
            # label's head, so they are shared between heads - confirm this is intended.
            heads = {
                label_name: nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs),
                )
                for label_name, num_outputs in num_outputs_for_label.items()
            }

        elif net_name.startswith('ViT'):
            # Keep the layer name 'head' used by the original ViT heads.
            in_features = cls.get_classifier(net_name).head.in_features
            heads = {
                label_name: nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                    ])
                )
                for label_name, num_outputs in num_outputs_for_label.items()
            }

        else:
            raise ValueError(f"No specified net: {net_name}.")

        return nn.ModuleDict(heads)

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return the number of input features of the head of net_name.

        According to the original note, this method is used by MultiNetFusion only.

        Args:
            net_name (str): 'MLP' or network name

        Returns:
            int: in_features of the head

        Raises:
            ValueError: if net_name matches none of the known families

        Where in_features is read from, per family:
            MLP:             last entry of mlp_config['hidden_channels']
            ResNet/DenseNet: classifier.in_features
            EfficientNet:    classifier[1].in_features
            ConvNeXt:        classifier[2].in_features
            ViT:             classifier.head.in_features
        """
        if net_name == 'MLP':
            return cls.mlp_config['hidden_channels'][-1]
        if net_name.startswith(('ResNet', 'DenseNet')):
            return cls.get_classifier(net_name).in_features
        if net_name.startswith('EfficientNet'):
            return cls.get_classifier(net_name)[1].in_features
        if net_name.startswith('ConvNeXt'):
            return cls.get_classifier(net_name)[2].in_features
        if net_name.startswith('ViT'):
            return cls.get_classifier(net_name).head.in_features
        raise ValueError(f"No specified net: {net_name}.")

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Build the module applied to the extractor output before feature fusion.

        Only ConvNeXt gets a real module, because ConvNeXt aligns the dimensions
        of its features inside its classifier:

            (classifier):
            Sequential(
                (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
                (1): Flatten(start_dim=1, end_dim=-1)
                (2): Linear(in_features=768, out_features=1000, bias=True)
            )

        The first two layers of that classifier are reused here. Every other
        network gets ``DUMMY``.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: Sequential(layer norm, flatten) for ConvNeXt, otherwise ``DUMMY``
        """
        if not net_name.startswith('ConvNeXt'):
            return cls.DUMMY

        reference_head = cls.get_classifier(net_name)
        return nn.Sequential(reference_head[0], reference_head[1])

    @classmethod
    def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
        """
        Return the last feature-extracting layer of a network, for Grad-CAM.

        net should already have its weights loaded and must expose ``extractor_net``.

        Args:
            net (nn.Module): model holding the extractor as ``extractor_net``
            mlp (str): 'MLP', otherwise None; only used in the assertion message
            net_name (str): network name

        Returns:
            nn.Module: last extractor layer of the network

        Raises:
            ValueError: if net_name matches none of the known families
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."

        backbone = net.extractor_net

        if net_name.startswith('ResNet'):
            return backbone.layer4[-1]
        if net_name.startswith('DenseNet'):
            return backbone.features.denseblock4.denselayer24
        if net_name.startswith('EfficientNet'):
            return backbone.features[-1]
        if net_name.startswith('ConvNeXt'):
            return backbone.features[-1][-1].block
        if net_name.startswith('ViT'):
            return backbone.encoder.layers[-1]
        raise ValueError(f"Cannot get last extractor of net: {net_name}.")
|
|
|
|
|
|
|
|
class MultiMixin:
    """
    Mixin that adds the per-label forward pass used by multi-label models.

    The host class must provide ``multi_classifier``, a mapping from label name
    to a callable head. Only its ``items()`` method is used here.
    """

    def multi_forward(self, out_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Feed the extractor output to the classifier of every label.

        Args:
            out_features (torch.Tensor): output from the extractor; the same
                object is passed to every head

        Returns:
            Dict[str, torch.Tensor]: output of the classifier of each label,
                keyed by label name, in the iteration order of ``multi_classifier``
        """
        return {
            label_name: classifier(out_features)
            for label_name, classifier in self.multi_classifier.items()
        }
|
|
|
|
|
|
|
|
class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Common base that combines nn.Module, BaseNet and MultiMixin.

    It adds nothing of its own; it only lets the model classes inherit
    all three at once.
    """
|
|
|
|
|
|
|
|
class MultiNet(MultiWidget):
    """
    Single-input model: an MLP, CNN or ViT extractor followed by one head per label.
    """

    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): 'MLP', CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of the MLP
            in_channel (int): number of image channels, ie gray scale(=1) or color image(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True to use a pretrained CNN or ViT, otherwise False
        """
        super().__init__()

        # Keep the constructor arguments on the instance.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Feature extractor first, then one head per label.
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained,
        )
        self.multi_classifier = self.construct_multi_classifier(
            net_name=net_name,
            num_outputs_for_label=num_outputs_for_label,
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the extractor, then every label head on its output.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: output of each label head, keyed by label name
        """
        return self.multi_forward(self.extractor_net(x))
|
|
|
|
|
|
|
|
class MultiNetFusion(MultiWidget):
    """
    Two-input model that fuses MLP features with CNN or ViT features.

    Both feature vectors are concatenated, passed through an intermediate MLP
    and then through one head per label.
    """

    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. Must not be 'MLP', since the MLP branch
                is always part of the fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of the MLP. Defaults to None.
            in_channel (int): number of image channels, ie gray scale(=1) or color image(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True to use a pretrained CNN or ViT, otherwise False
        """
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        # Keep the constructor arguments on the instance.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # One extractor per input, plus the module applied to the image features.
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained,
        )
        self.aux_module = self.construct_aux_module(net_name)

        # The intermediate MLP takes the concatenation of both feature vectors.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # The heads follow an MLP, so they are built as for 'MLP'.
        self.multi_classifier = self.construct_multi_classifier(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
        )

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract features from both inputs, fuse them and classify.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output of each label head, keyed by label name
        """
        tabular_features = self.extractor_mlp(x_mlp)
        image_features = self.aux_module(self.extractor_net(x_net))

        # Concatenate along dim 1 before the intermediate MLP.
        fused = self.inter_mlp(torch.cat([tabular_features, image_features], dim=1))
        return self.multi_forward(fused)
|
|
|
|
|
|
|
|
def create_net(
    mlp: Optional[str] = None,
    net: Optional[str] = None,
    num_outputs_for_label: Dict[str, int] = None,
    mlp_num_inputs: int = None,
    in_channel: int = None,
    vit_image_size: int = None,
    pretrained: bool = None
) -> nn.Module:
    """
    Create the model matching the given combination of mlp and net.

    Only whether ``mlp`` and ``net`` are None decides the model type:
    mlp only gives a MultiNet built as 'MLP' (with pretrained forced to False),
    net only gives a MultiNet built on ``net``, and both give a MultiNetFusion.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of inputs of the MLP
        in_channel (int): number of image channels, ie gray scale(=1) or color image(=3)
        vit_image_size (int): image size to be input to ViT
        pretrained (bool): True to use a pretrained CNN or ViT, otherwise False

    Returns:
        nn.Module: network

    Raises:
        ValueError: if both mlp and net are None
    """
    uses_mlp = mlp is not None
    uses_net = net is not None

    if not (uses_mlp or uses_net):
        raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")

    # Arguments that are passed unchanged to every model type.
    common_kwargs = dict(
        num_outputs_for_label=num_outputs_for_label,
        mlp_num_inputs=mlp_num_inputs,
        in_channel=in_channel,
        vit_image_size=vit_image_size,
    )

    if uses_mlp and uses_net:
        return MultiNetFusion(net_name=net, pretrained=pretrained, **common_kwargs)
    if uses_net:
        return MultiNet(net_name=net, pretrained=pretrained, **common_kwargs)
    # MLP only: there are no pretrained weights to load.
    return MultiNet(net_name='MLP', pretrained=False, **common_kwargs)
|
|
|