Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
# -*- coding: utf-8 -*-
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.ops import MLP | |
| import torchvision.models as models | |
| from typing import Dict, Optional | |
class BaseNet:
    """
    Factory helpers to construct MLP, CNN, and ViT networks.

    Holds lookup tables mapping the project's network names to torchvision
    constructors and to the attribute name of each architecture's classifier
    head. All methods are classmethods; no instance state is used.
    """
    # CNN constructors selectable by name.
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large
        }

    # ViT constructors selectable by name.
    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14
        }

    # Every selectable network (CNN + ViT).
    net = {**cnn, **vit}

    # Classifier attribute name per architecture family.
    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads'
        }

    # Classifier attribute name per concrete network name.
    classifier = {
        'ResNet18': _classifier['ResNet'],
        'ResNet': _classifier['ResNet'],
        'DenseNet': _classifier['DenseNet'],
        'EfficientNetB0': _classifier['EfficientNet'],
        'EfficientNetB2': _classifier['EfficientNet'],
        'EfficientNetB4': _classifier['EfficientNet'],
        'EfficientNetB6': _classifier['EfficientNet'],
        'EfficientNetV2s': _classifier['EfficientNet'],
        'EfficientNetV2m': _classifier['EfficientNet'],
        'EfficientNetV2l': _classifier['EfficientNet'],
        'ConvNeXtTiny': _classifier['ConvNext'],
        'ConvNeXtSmall': _classifier['ConvNext'],
        'ConvNeXtBase': _classifier['ConvNext'],
        'ConvNeXtLarge': _classifier['ConvNext'],
        'ViTb16': _classifier['ViT'],
        'ViTb32': _classifier['ViT'],
        'ViTl16': _classifier['ViT'],
        'ViTl32': _classifier['ViT'],
        'ViTH14': _classifier['ViT']
        }

    # Shared hyperparameters for every MLP built by MLPNet().
    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2
        }

    # Placeholder module used to strip a classifier head off an extractor.
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): the number of inputs of MLP
            inplace (bool, optional): parameter for the activation layer, which can
                                      optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP with cls.mlp_config hidden channels and dropout
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(
            in_channels=mlp_num_inputs,
            hidden_channels=cls.mlp_config['hidden_channels'],
            inplace=inplace,
            dropout=cls.mlp_config['dropout']
            )
        return mlp

    @classmethod
    def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
        """
        Modify network to handle grayscale (1ch) images.

        The stem convolution's weight is summed over the RGB input dimension,
        so a pretrained network responds to a gray image as it would to the
        same image replicated across three channels.

        Args:
            net_name (str): network name
            net (nn.Module): network itself

        Returns:
            nn.Module: network available for grayscale input

        Raises:
            ValueError: if net_name does not match any known family
        """
        if net_name.startswith('ResNet'):
            net.conv1.in_channels = 1
            net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('DenseNet'):
            net.features.conv0.in_channels = 1
            net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('Efficient') or net_name.startswith('ConvNeXt'):
            # Both families place their stem convolution at features[0][0].
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
        elif net_name.startswith('ViT'):
            net.conv_proj.in_channels = 1
            net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))
        else:
            raise ValueError(f"No specified net: {net_name}.")
        return net

    @classmethod
    def set_net(
            cls,
            net_name: str = None,
            in_channel: int = None,
            vit_image_size: int = None,
            pretrained: bool = None
            ) -> nn.Module:
        """
        Construct network, modified depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel (1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used.
                                            Should be a multiple of the patch size. Defaults to None.
            pretrained (bool, optional): True to use pretrained CNN weights. Defaults to None.
                                         Ignored for ViT, which is always pretrained.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # ViT: always built from pretrained weights (see set_vit).
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)

        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
        """
        Construct ViT whose positional embeddings are resized for vit_image_size.

        Args:
            net_name (str): ViT name
            vit_image_size (int): image size which ViT handles

        Returns:
            nn.Module: ViT with pretrained weights interpolated to the new image size
        """
        base_vit = cls.vit[net_name]
        pretrained_vit = base_vit(weights='DEFAULT')

        # Interpolate the pretrained positional embeddings to the new grid size.
        weight = pretrained_vit.state_dict()
        patch_size = int(net_name[-2:])  # 'ViTb16' -> 16
        aligned_weight = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=patch_size,
            model_state=weight
            )
        aligned_vit = base_vit(image_size=vit_image_size)  # Specify new image size.
        aligned_vit.load_state_dict(aligned_weight)        # Load weight which can handle the new image size.
        return aligned_vit

    @classmethod
    def construct_extractor(
            cls,
            net_name: str = None,
            mlp_num_inputs: int = None,
            in_channel: int = None,
            vit_image_size: int = None,
            pretrained: bool = None
            ) -> nn.Module:
        """
        Construct the feature extractor of a network depending on net_name.

        Args:
            net_name (str): network name ('MLP', CNN, or ViT name)
            mlp_num_inputs (int, optional): number of inputs of MLP. Defaults to None.
            in_channel (int, optional): image channel (1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
            pretrained (bool, optional): True to use pretrained CNN or ViT. Defaults to None.

        Returns:
            nn.Module: extractor of network (classifier head replaced by nn.Identity)
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(
                net_name=net_name,
                in_channel=in_channel,
                vit_image_size=vit_image_size,
                pretrained=pretrained
                )
            # Replace classifier with DUMMY(=nn.Identity()) so the network
            # emits raw features. Only valid here: cls.classifier has no
            # 'MLP' key, so this must stay inside the else branch.
            setattr(extractor, cls.classifier[net_name], cls.DUMMY)
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get classifier of network depending on net_name.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of a freshly built (non-pretrained) network
        """
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
        """
        Construct classifier for multi-label.

        One head per label, each shaped like the original network's
        classifier but with the label's number of outputs.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label

        Raises:
            ValueError: if net_name does not match any known family
        """
        classifiers = dict()

        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('EfficientNet'):
            # EfficientNet classifier = Sequential(Dropout, Linear).
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs)
                    )

        elif net_name.startswith('ConvNeXt'):
            # ConvNeXt classifier = Sequential(LayerNorm2d, Flatten, Linear);
            # shape is changed before nn.Linear.
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs)
                    )

        elif net_name.startswith('ViT'):
            # ViT head is wrapped in an OrderedDict to keep the 'head' name.
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                        ])
                    )

        else:
            raise ValueError(f"No specified net: {net_name}.")

        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return in_features of the classifier of the network indicated by net_name.

        Used in class MultiNetFusion() only.

        Args:
            net_name (str): network name

        Returns:
            int: in_features of the network's classifier

        Raises:
            ValueError: if net_name does not match any known family
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features
        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features
        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Construct module to align the shape of the feature from the extractor.

        Only ConvNeXt needs this: its classifier begins with
        LayerNorm2d + Flatten before the Linear layer, so after the
        classifier is replaced by nn.Identity those two layers must be
        re-applied to the extractor output. All other networks get
        nn.Identity().

        Args:
            net_name (str): network name

        Returns:
            nn.Module: layers aligning the extractor output shape
        """
        aux_module = cls.DUMMY
        if net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            aux_module = nn.Sequential(
                layer_norm,
                flatten
                )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
        """
        Return the last extractor block of the network (for Grad-CAM).

        net should be one with loaded weights.

        Args:
            net (nn.Module): network itself; must expose .extractor_net
            mlp (str): 'MLP', otherwise None (only used in the assert message)
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network

        Raises:
            ValueError: if net_name does not match any known family
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."
        _extractor = net.extractor_net
        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            # NOTE(review): hard-coded denselayer24 matches densenet161's
            # last block only — confirm if other DenseNets are ever added.
            last_extractor = _extractor.features.denseblock4.denselayer24
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor
class MultiMixin:
    """
    Mixin providing helpers to handle multi-label outputs.
    """
    def multi_forward(self, out_features: int) -> Dict[str, float]:
        """
        Apply every label's classifier to the shared extractor features.

        Args:
            out_features (int): output from extractor

        Returns:
            Dict[str, float]: output of classifier of each label
        """
        return {
            label_name: classifier(out_features)
            for label_name, classifier in self.multi_classifier.items()
        }
class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Convenience base combining nn.Module with the BaseNet construction
    helpers and the MultiMixin multi-label forwarding, so concrete models
    inherit all three at once.
    """
class MultiNet(MultiWidget):
    """
    Single-input model: an MLP, CNN or ViT extractor followed by one
    classifier head per label.
    """
    def __init__(
            self,
            net_name: str = None,
            num_outputs_for_label: Dict[str, int] = None,
            mlp_num_inputs: int = None,
            in_channel: int = None,
            vit_image_size: int = None,
            pretrained: bool = None
            ) -> None:
        """
        Args:
            net_name (str): MLP, CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of MLP
            in_channel (int): number of image channels, ie grayscale(=1) or color(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True to use pretrained CNN or ViT, otherwise False
        """
        super().__init__()
        # Keep the full configuration on the instance for later introspection.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Feature extractor: an MLP or a computer-vision backbone.
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
            )
        # One classifier head per label.
        self.multi_classifier = self.construct_multi_classifier(
            net_name=net_name,
            num_outputs_for_label=num_outputs_for_label
            )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: output per label
        """
        return self.multi_forward(self.extractor_net(x))
class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.

    Tabular features (MLP) and image features (CNN/ViT) are extracted
    separately, concatenated, mixed by an intermediate MLP, and finally
    classified per label.
    """
    def __init__(
            self,
            net_name: str = None,
            num_outputs_for_label: Dict[str, int] = None,
            mlp_num_inputs: int = None,
            in_channel: int = None,
            vit_image_size: int = None,
            pretrained: bool = None
            ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of MLP. Defaults to None.
            in_channel (int): number of image channels, ie grayscale(=1) or color(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True to use pretrained CNN or ViT, otherwise False

        Raises:
            ValueError: if net_name is 'MLP'
        """
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently accept an invalid configuration.
        if net_name == 'MLP':
            raise ValueError('net_name should not be MLP.')

        super().__init__()
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Extractors for tabular data (MLP) and image data (CNN/ViT).
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
            )
        # Aligns the image feature shape; only non-trivial for ConvNeXt.
        self.aux_module = self.construct_aux_module(self.net_name)

        # Intermediate MLP over the concatenated MLP + image features.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(self.net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Per-label classifier heads sized for the intermediate MLP output.
        self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output per label
        """
        out_mlp = self.extractor_mlp(x_mlp)
        out_net = self.extractor_net(x_net)
        out_net = self.aux_module(out_net)

        out_features = torch.cat([out_mlp, out_net], dim=1)
        out_features = self.inter_mlp(out_features)
        output = self.multi_forward(out_features)
        return output
def create_net(
        mlp: Optional[str] = None,
        net: Optional[str] = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
        ) -> nn.Module:
    """
    Create a network from the given configuration.

    Which model is built depends on which of mlp/net is given:
    only mlp -> MLP-only MultiNet; only net -> CNN/ViT MultiNet;
    both -> MultiNetFusion.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of inputs of MLP
        in_channel (int): number of image channels, ie grayscale(=1) or color(=3)
        vit_image_size (int): image size to be input to ViT
        pretrained (bool): True to use pretrained CNN or ViT, otherwise False

    Returns:
        nn.Module: network

    Raises:
        ValueError: if neither mlp nor net is given
    """
    has_mlp = mlp is not None
    has_net = net is not None

    if has_mlp and not has_net:
        return MultiNet(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=False  # No need of pretrained for MLP
            )

    if has_net and not has_mlp:
        return MultiNet(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
            )

    if has_mlp and has_net:
        return MultiNetFusion(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
            )

    raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")