| | from timm import create_model |
| | from torch import nn, Tensor |
| | from typing import Optional |
| | from functools import partial |
| |
|
| | from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample |
| | from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample |
| | from ..utils import ConvRefine, LightConvRefine, LighterConvRefine |
| |
|
# Backbone names grouped by the family of conv blocks TIMMModel pairs them with.
# Names in `regular_models` and `heavy_models` take the grouped Conv* blocks,
# `light_models` take the LightConv* blocks and `lighter_models` take the
# LighterConv* blocks.
regular_models = [
    # ResNet
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    # ConvNeXt (nano .. base)
    "convnext_nano",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    # MobileNetV4 (large)
    "mobilenetv4_conv_large",
]

heavy_models = [
    "convnext_large",
    "convnext_xlarge",
    "convnext_xxlarge",
]

light_models = [
    "mobilenetv1_100",
    "mobilenetv1_125",
    "mobilenetv2_100",
    "mobilenetv2_140",
    "mobilenetv3_large_100",
    "mobilenetv4_conv_medium",
]

lighter_models = [
    "mobilenetv2_050",
    "mobilenetv3_small_050",
    "mobilenetv3_small_075",
    "mobilenetv3_small_100",
    "mobilenetv4_conv_small_050",
    "mobilenetv4_conv_small",
]

# Every backbone name accepted by TIMMModel, in family order.
supported_models = regular_models + heavy_models + light_models + lighter_models
| |
|
| |
|
| | refiner_in_channels = { |
| | |
| | "resnet18": 512, |
| | "resnet34": 512, |
| | "resnet50": 2048, |
| | "resnet101": 2048, |
| | "resnet152": 2048, |
| | |
| | "convnext_nano": 640, |
| | "convnext_tiny": 768, |
| | "convnext_small": 768, |
| | "convnext_base": 1024, |
| | "convnext_large": 1536, |
| | "convnext_xlarge": 2048, |
| | "convnext_xxlarge": 3072, |
| | |
| | "mobilenetv1_100": 1024, |
| | "mobilenetv1_125": 1280, |
| | |
| | "mobilenetv2_050": 160, |
| | "mobilenetv2_100": 320, |
| | "mobilenetv2_140": 448, |
| | |
| | "mobilenetv3_small_050": 288, |
| | "mobilenetv3_small_075": 432, |
| | "mobilenetv3_small_100": 576, |
| | "mobilenetv3_large_100": 960, |
| | |
| | "mobilenetv4_conv_small_050": 480, |
| | "mobilenetv4_conv_small": 960, |
| | "mobilenetv4_conv_medium": 960, |
| | "mobilenetv4_conv_large": 960, |
| | } |
| |
|
| |
|
| | refiner_out_channels = { |
| | |
| | "resnet18": 512, |
| | "resnet34": 512, |
| | "resnet50": 2048, |
| | "resnet101": 2048, |
| | "resnet152": 2048, |
| | |
| | "convnext_nano": 640, |
| | "convnext_tiny": 768, |
| | "convnext_small": 768, |
| | "convnext_base": 1024, |
| | "convnext_large": 1536, |
| | "convnext_xlarge": 2048, |
| | "convnext_xxlarge": 3072, |
| | |
| | "mobilenetv1_100": 512, |
| | "mobilenetv1_125": 640, |
| | |
| | "mobilenetv2_050": 160, |
| | "mobilenetv2_100": 320, |
| | "mobilenetv2_140": 448, |
| | |
| | "mobilenetv3_small_050": 288, |
| | "mobilenetv3_small_075": 432, |
| | "mobilenetv3_small_100": 576, |
| | "mobilenetv3_large_100": 480, |
| | |
| | "mobilenetv4_conv_small_050": 480, |
| | "mobilenetv4_conv_small": 960, |
| | "mobilenetv4_conv_medium": 960, |
| | "mobilenetv4_conv_large": 960, |
| | } |
| |
|
| |
|
# Value passed as the `groups=` keyword to ConvUpsample / ConvDownsample /
# ConvRefine, keyed by backbone name. TIMMModel only reads this table for
# backbones outside `light_models` and `lighter_models`, so the None entries
# below are placeholders that are never looked up by it.
groups = {
    # ResNet: the wide variants use one group per 512 channels (2048 // 512 == 4).
    "resnet18": 1,
    "resnet34": 1,
    "resnet50": refiner_in_channels["resnet50"] // 512,
    "resnet101": refiner_in_channels["resnet101"] // 512,
    "resnet152": refiner_in_channels["resnet152"] // 512,
    # ConvNeXt: fixed 8 groups up to `base`, one group per 512 channels above.
    "convnext_nano": 8,
    "convnext_tiny": 8,
    "convnext_small": 8,
    "convnext_base": 8,
    "convnext_large": refiner_in_channels["convnext_large"] // 512,
    "convnext_xlarge": refiner_in_channels["convnext_xlarge"] // 512,
    "convnext_xxlarge": refiner_in_channels["convnext_xxlarge"] // 512,
    # MobileNetV1
    "mobilenetv1_100": None,
    "mobilenetv1_125": None,
    # MobileNetV2
    "mobilenetv2_050": None,
    "mobilenetv2_100": None,
    "mobilenetv2_140": None,
    # MobileNetV3
    "mobilenetv3_small_050": None,
    "mobilenetv3_small_075": None,
    "mobilenetv3_small_100": None,
    "mobilenetv3_large_100": None,
    # MobileNetV4: only `conv_large` is listed in `regular_models`.
    "mobilenetv4_conv_small_050": None,
    "mobilenetv4_conv_small": None,
    "mobilenetv4_conv_medium": None,
    "mobilenetv4_conv_large": 1,
}
| |
|
| |
|
class TIMMModel(nn.Module):
    """timm feature backbone followed by a refiner and a decoder.

    The model is a three-step pipeline:

    * ``encoder``: ``create_model(model_name, pretrained=True,
      features_only=True, out_indices=[-1])``.
    * ``refiner``: resamples the encoder output so the overall reduction
      becomes ``block_size``. It is ``nn.Identity`` when ``block_size`` equals
      the encoder reduction, one resampling block for a factor of 2, and an
      ``nn.Sequential`` of two blocks for a factor of 4 (in either direction).
    * ``decoder``: zero to three refine blocks, each configured with half the
      channels of its input; the count depends on ``refiner_channels``.

    Args:
        model_name: Backbone name; must be listed in ``supported_models``.
        block_size: Desired overall reduction. ``None`` keeps the encoder's own
            reduction; otherwise it must be one of 8, 16, 32 and differ from
            the encoder reduction by a factor of 1, 2 or 4.
        norm: ``"bn"`` selects ``nn.BatchNorm2d``, ``"ln"`` selects
            ``nn.LayerNorm``; any other value defers to
            ``_get_norm_layer(self.encoder)``.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)``, ``"gelu"`` selects
            ``nn.GELU()``; any other value defers to
            ``_get_activation(self.encoder)``.

    Raises:
        AssertionError: If ``model_name`` or ``block_size`` is not supported.
            NOTE: these checks are ``assert`` statements, so they are skipped
            when Python runs with ``-O``.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        # --- Encoder: last feature map of the timm backbone only. ---
        self.encoder = create_model(model_name, pretrained=True, features_only=True, out_indices=[-1])
        self.encoder_channels = self.encoder.feature_info.channels()[-1]
        self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
        self.block_size = self.encoder_reduction if block_size is None else block_size

        # --- Pick the block family that matches the backbone's weight class. ---
        if model_name in lighter_models:
            upsample_block, downsample_block, decoder_block = (
                LighterConvUpsample,
                LighterConvDownsample,
                LighterConvRefine,
            )
        elif model_name in light_models:
            upsample_block, downsample_block, decoder_block = (
                LightConvUpsample,
                LightConvDownsample,
                LightConvRefine,
            )
        else:
            # Regular and heavy backbones use the grouped Conv* blocks.
            num_groups = groups[model_name]
            upsample_block, downsample_block, decoder_block = (
                partial(block_cls, groups=num_groups)
                for block_cls in (ConvUpsample, ConvDownsample, ConvRefine)
            )

        # --- Normalisation / activation: explicit choice or copied from the encoder. ---
        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(self.encoder)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(self.encoder)

        # Keyword arguments shared by every refiner and decoder block.
        layer_kwargs = {"norm_layer": norm_layer, "activation": activation}

        # --- Refiner: decide direction and number of resampling stages. ---
        if self.block_size > self.encoder_reduction:
            resample_block = downsample_block
            two_stage = self.block_size > self.encoder_reduction * 2
            if two_stage:
                assert self.block_size == self.encoder_reduction * 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
            else:
                assert self.block_size == self.encoder_reduction * 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
        elif self.block_size < self.encoder_reduction:
            resample_block = upsample_block
            two_stage = self.block_size < self.encoder_reduction // 2
            if two_stage:
                assert self.block_size == self.encoder_reduction // 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
            else:
                assert self.block_size == self.encoder_reduction // 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
        else:
            resample_block = None
            two_stage = False

        if resample_block is None:
            # The encoder already produces the requested reduction.
            self.refiner = nn.Identity()
            self.refiner_channels = self.encoder_channels
        else:
            out_width = refiner_out_channels[self.model_name]
            if two_stage:
                mid_width = refiner_in_channels[self.model_name]
                self.refiner = nn.Sequential(
                    resample_block(in_channels=self.encoder_channels, out_channels=mid_width, **layer_kwargs),
                    resample_block(in_channels=mid_width, out_channels=out_width, **layer_kwargs),
                )
            else:
                # A single stage is stored bare (not wrapped in nn.Sequential).
                self.refiner = resample_block(
                    in_channels=self.encoder_channels,
                    out_channels=out_width,
                    **layer_kwargs,
                )
            self.refiner_channels = out_width

        self.refiner_reduction = self.block_size

        # --- Decoder: halve the channels 0-3 times depending on the refiner width. ---
        width = self.refiner_channels
        if width <= 256:
            num_stages = 0
        elif width <= 512:
            num_stages = 1
        elif width <= 1024:
            num_stages = 2
        else:
            num_stages = 3

        stages = [
            decoder_block(
                in_channels=width // 2 ** depth,
                out_channels=width // 2 ** (depth + 1),
                **layer_kwargs,
            )
            for depth in range(num_stages)
        ]
        if not stages:
            self.decoder = nn.Identity()
        elif len(stages) == 1:
            # A single stage is stored bare (not wrapped in nn.Sequential).
            self.decoder = stages[0]
        else:
            self.decoder = nn.Sequential(*stages)
        self.decoder_channels = width // 2 ** num_stages

        # The decoder is built from refine blocks only, so the reduction is carried over.
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the encoder and return the first entry of its output."""
        features = self.encoder(x)
        return features[0]

    def refine(self, x: Tensor) -> Tensor:
        """Apply the refiner to an encoder feature map."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Apply the decoder to a refined feature map."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Chain ``encode``, ``refine`` and ``decode`` on ``x``."""
        return self.decode(self.refine(self.encode(x)))
| |
|
| |
|
def _timm_model(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> TIMMModel:
    """Build a ``TIMMModel`` from the given backbone name and options.

    All arguments are forwarded unchanged to ``TIMMModel``.
    """
    model = TIMMModel(
        model_name=model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )
    return model
| |
|