| | |
| | |
| |
|
| | from collections import OrderedDict |
| | import torch |
| | import torch.nn as nn |
| | from torchvision.ops import MLP |
| | import torchvision.models as models |
| | from typing import Dict, Optional |
| |
|
| |
|
class BaseNet:
    """
    Class to construct network.

    Factory helpers that build MLP/CNN/ViT feature extractors,
    per-label classifier heads, and auxiliary alignment modules
    on top of torchvision model constructors.
    """
    # Supported CNN constructors, keyed by project-wide network name.
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large
    }

    # Supported ViT constructors. The trailing digits of each name are
    # the patch size (parsed by set_vit()).
    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14
    }

    # All image networks (CNN + ViT).
    net = {**cnn, **vit}

    # Attribute name of the classifier head per network family.
    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads'
    }

    # Classifier attribute name per concrete network name.
    classifier = {
        'ResNet18': _classifier['ResNet'],
        'ResNet': _classifier['ResNet'],
        'DenseNet': _classifier['DenseNet'],
        'EfficientNetB0': _classifier['EfficientNet'],
        'EfficientNetB2': _classifier['EfficientNet'],
        'EfficientNetB4': _classifier['EfficientNet'],
        'EfficientNetB6': _classifier['EfficientNet'],
        'EfficientNetV2s': _classifier['EfficientNet'],
        'EfficientNetV2m': _classifier['EfficientNet'],
        'EfficientNetV2l': _classifier['EfficientNet'],
        'ConvNeXtTiny': _classifier['ConvNext'],
        'ConvNeXtSmall': _classifier['ConvNext'],
        'ConvNeXtBase': _classifier['ConvNext'],
        'ConvNeXtLarge': _classifier['ConvNext'],
        'ViTb16': _classifier['ViT'],
        'ViTb32': _classifier['ViT'],
        'ViTl16': _classifier['ViT'],
        'ViTl32': _classifier['ViT'],
        'ViTH14': _classifier['ViT']
    }

    # Shared MLP shape: hidden layer sizes and dropout rate for every
    # MLP built by MLPNet().
    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2
    }

    # Identity placeholder used to replace a classifier head when only the
    # feature extractor is wanted. nn.Identity holds no parameters, so a
    # single shared instance is safe.
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: Optional[int] = None, inplace: Optional[bool] = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): the number of input of MLP
            inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP with hidden sizes and dropout taken from cls.mlp_config
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
        return mlp

    @classmethod
    def align_in_channels_1ch(cls, net_name: Optional[str] = None, net: Optional[nn.Module] = None) -> nn.Module:
        """
        Modify network to handle gray scale image.

        The first convolution's RGB kernels are collapsed to a single
        channel by summing over the input-channel dimension, so the
        pretrained spatial filters are preserved.

        Args:
            net_name (str): network name
            net (nn.Module): network itself

        Returns:
            nn.Module: network available for gray scale

        Raises:
            ValueError: if net_name does not match any supported family
        """
        if net_name.startswith('ResNet'):
            net.conv1.in_channels = 1
            net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('DenseNet'):
            net.features.conv0.in_channels = 1
            net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('Efficient'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ConvNeXt'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ViT'):
            net.conv_proj.in_channels = 1
            net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return net

    @classmethod
    def set_net(
        cls,
        net_name: Optional[str] = None,
        in_channel: Optional[int] = None,
        vit_image_size: Optional[int] = None,
        pretrained: Optional[bool] = None
    ) -> nn.Module:
        """
        Modify network depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
                                            vit_image_size should be power of patch size.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # BUGFIX: forward the pretrained flag. Previously set_vit()
            # always loaded pretrained weights, silently ignoring
            # pretrained=False for ViT models.
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size, pretrained=bool(pretrained))

        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: Optional[str] = None, vit_image_size: Optional[int] = None, pretrained: bool = True) -> nn.Module:
        """
        Modify ViT depending on vit_image_size.

        When pretrained, the pretrained positional embeddings are
        interpolated to fit vit_image_size; otherwise a randomly
        initialized ViT for that image size is returned.

        Args:
            net_name (str): ViT name
            vit_image_size (int): image size which ViT handles if ViT is used.
            pretrained (bool): load and interpolate pretrained weights.
                               Defaults to True (the previous behavior).

        Returns:
            nn.Module: modified ViT
        """
        base_vit = cls.vit[net_name]
        if not pretrained:
            # No weights to adapt: just build a ViT for the target size.
            return base_vit(image_size=vit_image_size)

        pretrained_vit = base_vit(weights='DEFAULT')

        # Resize positional embeddings so the pretrained state dict fits
        # a model built for vit_image_size.
        weight = pretrained_vit.state_dict()
        patch_size = int(net_name[-2:])  # e.g. 'ViTb16' -> 16
        aligned_weight = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=patch_size,
            model_state=weight
        )
        aligned_vit = base_vit(image_size=vit_image_size)
        aligned_vit.load_state_dict(aligned_weight)
        return aligned_vit

    @classmethod
    def construct_extractor(
        cls,
        net_name: Optional[str] = None,
        mlp_num_inputs: Optional[int] = None,
        in_channel: Optional[int] = None,
        vit_image_size: Optional[int] = None,
        pretrained: Optional[bool] = None
    ) -> nn.Module:
        """
        Construct extractor of network depending on net_name.

        For image networks, the classifier head is replaced with
        nn.Identity so the network emits raw features.

        Args:
            net_name (str): network name.
            mlp_num_inputs (int, optional): number of input of MLP. Defaults to None.
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: extractor of network
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
            # Strip the classification head; features pass through untouched.
            setattr(extractor, cls.classifier[net_name], cls.DUMMY)
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get classifier of network depending on net_name.

        Instantiates a fresh (randomly initialized) network and returns
        its classifier head; used only to inspect the head's structure.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of network
        """
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: Optional[str] = None, num_outputs_for_label: Optional[Dict[str, int]] = None) -> nn.ModuleDict:
        """
        Construct classifier for multi-label.

        Each label gets its own head mirroring the structure of the
        original network's classifier.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label

        Raises:
            ValueError: if net_name does not match any supported family
        """
        classifiers = dict()
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            # NOTE(review): the same layer_norm/flatten module instances are
            # shared across all labels, so the LayerNorm parameters are
            # trained jointly — confirm this sharing is intentional.
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                # Keep the key 'head' so the head mirrors ViT's own naming.
                classifiers[label_name] = nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                    ])
                )

        else:
            raise ValueError(f"No specified net: {net_name}.")

        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return in_feature of network indicating by net_name.
        This class is used in class MultiNetFusion() only.

        Args:
            net_name (str): net_name

        Returns:
            int : in_feature

        Raises:
            ValueError: if net_name does not match any supported family

        Required:
            classifier.in_feature
            classifier.[1].in_features
            classifier.[2].in_features
            classifier.head.in_features
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Construct module to align the shape of feature from extractor depending on network.
        Actually, only when net_name == 'ConvNeXt'.
        Because ConvNeXt has the process of aligning the dimensions in its classifier.

        Needs to align shape of the feature extractor when ConvNeXt
        (classifier):
            Sequential(
                (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
                (1): Flatten(start_dim=1, end_dim=-1)
                (2): Linear(in_features=768, out_features=1000, bias=True)
                )

        Args:
            net_name (str): net name

        Returns:
            nn.Module: layers such that they align the dimension of the output from the extractor like the original ConvNeXt.
                       nn.Identity for every other network family.
        """
        aux_module = cls.DUMMY
        if net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            aux_module = nn.Sequential(
                layer_norm,
                flatten
            )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: Optional[nn.Module] = None, mlp: Optional[str] = None, net_name: Optional[str] = None) -> nn.Module:
        """
        Return the last extractor of network.
        This is for Grad-CAM.
        net should be one loaded weight.

        Args:
            net (nn.Module): network itself
            mlp (str): 'MLP', otherwise None
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network

        Raises:
            ValueError: if net_name does not match any supported family
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."

        _extractor = net.extractor_net

        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            # NOTE(review): denselayer24 is specific to densenet161
            # (denseblock4 has 24 layers there) — the only DenseNet in
            # cls.cnn, but this would break for other DenseNet variants.
            last_extractor = _extractor.features.denseblock4.denselayer24
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor
| |
|
| |
|
class MultiMixin:
    """
    Class to define auxiliary function to handle multi-label.
    """
    def multi_forward(self, out_features: int) -> Dict[str, float]:
        """
        Apply every label-specific classifier to the extracted features.

        Args:
            out_features (int): output from extractor

        Returns:
            Dict[str, float]: classifier output keyed by label name
        """
        return {
            label_name: label_classifier(out_features)
            for label_name, label_classifier in self.multi_classifier.items()
        }
| |
|
| |
|
class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Class for a widget to inherit multiple classes simultaneously.

    Combines nn.Module's parameter/forward machinery, BaseNet's
    network-construction factory methods, and MultiMixin's
    multi-label forwarding helper into a single base class.
    """
    pass
| |
|
| |
|
class MultiNet(MultiWidget):
    """
    Model of MLP, CNN or ViT.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): MLP, CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        super().__init__()

        # Keep the construction arguments around for later inspection.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Feature extractor followed by one classifier head per label.
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
        self.multi_classifier = self.construct_multi_classifier(
            net_name=net_name,
            num_outputs_for_label=num_outputs_for_label
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract features and dispatch them to every label classifier.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: output per label
        """
        return self.multi_forward(self.extractor_net(x))
| |
|
| |
|
class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP. Defaults to None.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        # Keep the construction arguments around for later inspection.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Two parallel extractors: one for tabular data, one for images.
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
        # Aligns the image-feature shape (only non-trivial for ConvNeXt).
        self.aux_module = self.construct_aux_module(net_name)

        # Intermediate MLP that fuses the two concatenated feature vectors.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Heads sized for the fusion MLP's output.
        self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Fuse tabular and image features, then classify per label.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output per label
        """
        tabular_features = self.extractor_mlp(x_mlp)
        image_features = self.aux_module(self.extractor_net(x_net))

        fused = self.inter_mlp(torch.cat([tabular_features, image_features], dim=1))
        return self.multi_forward(fused)
| |
|
| |
|
def create_net(
    mlp: Optional[str] = None,
    net: Optional[str] = None,
    num_outputs_for_label: Dict[str, int] = None,
    mlp_num_inputs: int = None,
    in_channel: int = None,
    vit_image_size: int = None,
    pretrained: bool = None
) -> nn.Module:
    """
    Create network.

    Dispatches on which of mlp/net is given: MLP-only, image-only,
    or a fusion of both.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of input of MLP.
        in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
        vit_image_size (int): image size to be input to ViT.
        pretrained (bool): True when use pretrained CNN or ViT, otherwise False.

    Returns:
        nn.Module: network

    Raises:
        ValueError: when neither mlp nor net is given
    """
    has_mlp = mlp is not None
    has_net = net is not None

    if has_mlp and has_net:
        # Fusion of tabular MLP and image network.
        return MultiNetFusion(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    if has_net:
        # Image-only model (CNN or ViT).
        return MultiNet(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    if has_mlp:
        # Tabular-only model; pretraining never applies to the MLP.
        return MultiNet(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=False
        )
    raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")
| |
|