| from timm import create_model |
| from torch import nn, Tensor |
| from typing import Optional |
| from functools import partial |
|
|
| from ..utils import _get_activation, _get_norm_layer, ConvUpsample, ConvDownsample |
| from ..utils import LightConvUpsample, LightConvDownsample, LighterConvUpsample, LighterConvDownsample |
| from ..utils import ConvRefine, LightConvRefine, LighterConvRefine |
|
|
# Supported timm backbones, grouped by capacity.  The group a model falls in
# selects which Conv* refine/up/down-sample block variants TIMMModel attaches
# to it (regular/heavy use grouped convolutions; light/lighter use the
# lightweight block variants).

regular_models = [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "convnext_nano",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "mobilenetv4_conv_large",
]

heavy_models = [
    "convnext_large",
    "convnext_xlarge",
    "convnext_xxlarge",
]

light_models = [
    "mobilenetv1_100",
    "mobilenetv1_125",
    "mobilenetv2_100",
    "mobilenetv2_140",
    "mobilenetv3_large_100",
    "mobilenetv4_conv_medium",
]

lighter_models = [
    "mobilenetv2_050",
    "mobilenetv3_small_050",
    "mobilenetv3_small_075",
    "mobilenetv3_small_100",
    "mobilenetv4_conv_small_050",
    "mobilenetv4_conv_small",
]

# Flat list of every accepted backbone name, in group order.
supported_models = [*regular_models, *heavy_models, *light_models, *lighter_models]
|
|
|
|
# Channel width of each backbone's deepest feature map — i.e. the width the
# first refiner stage receives (and, for two-stage refiners, the intermediate
# width between the two stages).
refiner_in_channels = dict(
    # ResNets
    resnet18=512,
    resnet34=512,
    resnet50=2048,
    resnet101=2048,
    resnet152=2048,

    # ConvNeXts
    convnext_nano=640,
    convnext_tiny=768,
    convnext_small=768,
    convnext_base=1024,
    convnext_large=1536,
    convnext_xlarge=2048,
    convnext_xxlarge=3072,

    # MobileNet v1
    mobilenetv1_100=1024,
    mobilenetv1_125=1280,

    # MobileNet v2
    mobilenetv2_050=160,
    mobilenetv2_100=320,
    mobilenetv2_140=448,

    # MobileNet v3
    mobilenetv3_small_050=288,
    mobilenetv3_small_075=432,
    mobilenetv3_small_100=576,
    mobilenetv3_large_100=960,

    # MobileNet v4
    mobilenetv4_conv_small_050=480,
    mobilenetv4_conv_small=960,
    mobilenetv4_conv_medium=960,
    mobilenetv4_conv_large=960,
)
|
|
|
|
# Channel width the refiner emits for each backbone; also the width the
# decoder starts from.  Mostly mirrors refiner_in_channels, except the
# MobileNet v1 models and mobilenetv3_large_100, which are narrowed.
refiner_out_channels = dict(
    # ResNets
    resnet18=512,
    resnet34=512,
    resnet50=2048,
    resnet101=2048,
    resnet152=2048,

    # ConvNeXts
    convnext_nano=640,
    convnext_tiny=768,
    convnext_small=768,
    convnext_base=1024,
    convnext_large=1536,
    convnext_xlarge=2048,
    convnext_xxlarge=3072,

    # MobileNet v1
    mobilenetv1_100=512,
    mobilenetv1_125=640,

    # MobileNet v2
    mobilenetv2_050=160,
    mobilenetv2_100=320,
    mobilenetv2_140=448,

    # MobileNet v3
    mobilenetv3_small_050=288,
    mobilenetv3_small_075=432,
    mobilenetv3_small_100=576,
    mobilenetv3_large_100=480,

    # MobileNet v4
    mobilenetv4_conv_small_050=480,
    mobilenetv4_conv_small=960,
    mobilenetv4_conv_medium=960,
    mobilenetv4_conv_large=960,
)
|
|
|
|
# Grouped-convolution group counts passed (via functools.partial) to the
# regular/heavy Conv* blocks in TIMMModel.  The wide backbones target roughly
# 512 channels per group.  None marks backbones routed to the light/lighter
# block variants, which never read this table.
groups = dict(
    # ResNets: group only the wide (2048-ch) variants
    resnet18=1,
    resnet34=1,
    resnet50=refiner_in_channels["resnet50"] // 512,
    resnet101=refiner_in_channels["resnet101"] // 512,
    resnet152=refiner_in_channels["resnet152"] // 512,

    # ConvNeXts: fixed 8 groups for the small-to-base range,
    # ~512 channels/group for the heavy variants
    convnext_nano=8,
    convnext_tiny=8,
    convnext_small=8,
    convnext_base=8,
    convnext_large=refiner_in_channels["convnext_large"] // 512,
    convnext_xlarge=refiner_in_channels["convnext_xlarge"] // 512,
    convnext_xxlarge=refiner_in_channels["convnext_xxlarge"] // 512,

    # MobileNets v1-v3 and the small/medium v4 use light(er) blocks: unused
    mobilenetv1_100=None,
    mobilenetv1_125=None,

    mobilenetv2_050=None,
    mobilenetv2_100=None,
    mobilenetv2_140=None,

    mobilenetv3_small_050=None,
    mobilenetv3_small_075=None,
    mobilenetv3_small_100=None,
    mobilenetv3_large_100=None,

    mobilenetv4_conv_small_050=None,
    mobilenetv4_conv_small=None,
    mobilenetv4_conv_medium=None,
    mobilenetv4_conv_large=1,
)
|
|
|
|
class TIMMModel(nn.Module):
    """A timm backbone wrapped with a stride-matching refiner and a channel-reducing decoder.

    Pipeline: ``encode`` (timm backbone, deepest feature map only) ->
    ``refine`` (down/up-sample so the output stride equals ``block_size``) ->
    ``decode`` (halve the channel count one or more times until <= 256).

    Attributes exposed to callers:
        encoder_channels / encoder_reduction: width and stride of the raw encoder output.
        refiner_channels / refiner_reduction: width and stride after the refiner.
        decoder_channels / decoder_reduction: width and stride after the decoder.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        """Build the encoder, refiner, and decoder stages.

        Args:
            model_name: a timm model name; must be in ``supported_models``.
            block_size: target output stride (8, 16, or 32); ``None`` keeps the
                encoder's native stride (refiner becomes an identity).
            norm: "bn", "ln", or anything else to infer from the encoder.
            act: "relu", "gelu", or anything else to infer from the encoder.
        """
        super().__init__()
        # NOTE(review): asserts are stripped under `python -O`; raise ValueError
        # instead if these checks must survive optimized runs.
        assert model_name in supported_models, f"Backbone {model_name} not supported. Supported models are {supported_models}"
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name
        # features_only with out_indices=[-1]: forward() returns a one-element
        # list holding the deepest feature map (see encode()).
        self.encoder = create_model(model_name, pretrained=True, features_only=True, out_indices=[-1])
        self.encoder_channels = self.encoder.feature_info.channels()[-1]
        self.encoder_reduction = self.encoder.feature_info.reduction()[-1]
        # Default to the encoder's native stride when no block size is requested.
        self.block_size = block_size if block_size is not None else self.encoder_reduction

        # Pick block variants matching the backbone's capacity class; the
        # regular/heavy variants additionally take a per-model group count.
        if model_name in lighter_models:
            upsample_block = LighterConvUpsample
            downsample_block = LighterConvDownsample
            decoder_block = LighterConvRefine
        elif model_name in light_models:
            upsample_block = LightConvUpsample
            downsample_block = LightConvDownsample
            decoder_block = LightConvRefine
        else:
            upsample_block = partial(ConvUpsample, groups=groups[model_name])
            downsample_block = partial(ConvDownsample, groups=groups[model_name])
            decoder_block = partial(ConvRefine, groups=groups[model_name])

        # Normalization layer: explicit override, else inferred from the encoder.
        # NOTE(review): nn.LayerNorm normalizes trailing dims; for NCHW conv
        # features this assumes the Conv* blocks handle the layout — confirm.
        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(self.encoder)

        # Activation: explicit override, else inferred from the encoder.
        # NOTE(review): the same activation module instance is shared across
        # stages; fine for parameterless activations (ReLU/GELU).
        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(self.encoder)

        # Refiner: bridge the gap between the encoder stride and block_size.
        # Only 2x and 4x ratios are supported in either direction; a 4x gap
        # uses two chained stages routed through refiner_in_channels.
        if self.block_size > self.encoder_reduction:
            if self.block_size > self.encoder_reduction * 2:
                assert self.block_size == self.encoder_reduction * 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
                self.refiner = nn.Sequential(
                    downsample_block(
                        in_channels=self.encoder_channels,
                        out_channels=refiner_in_channels[self.model_name],
                        norm_layer=norm_layer,
                        activation=activation,
                    ),
                    downsample_block(
                        in_channels=refiner_in_channels[self.model_name],
                        out_channels=refiner_out_channels[self.model_name],
                        norm_layer=norm_layer,
                        activation=activation,
                    )
                )
            else:
                assert self.block_size == self.encoder_reduction * 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction * 2}, and {self.encoder_reduction * 4}."
                self.refiner = downsample_block(
                    in_channels=self.encoder_channels,
                    out_channels=refiner_out_channels[self.model_name],
                    norm_layer=norm_layer,
                    activation=activation,
                )

            self.refiner_channels = refiner_out_channels[self.model_name]

        elif self.block_size < self.encoder_reduction:
            # Mirror of the branch above, upsampling instead of downsampling.
            if self.block_size < self.encoder_reduction // 2:
                assert self.block_size == self.encoder_reduction // 4, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
                self.refiner = nn.Sequential(
                    upsample_block(
                        in_channels=self.encoder_channels,
                        out_channels=refiner_in_channels[self.model_name],
                        norm_layer=norm_layer,
                        activation=activation,
                    ),
                    upsample_block(
                        in_channels=refiner_in_channels[self.model_name],
                        out_channels=refiner_out_channels[self.model_name],
                        norm_layer=norm_layer,
                        activation=activation,
                    )
                )
            else:
                assert self.block_size == self.encoder_reduction // 2, f"Block size {self.block_size} is not supported for model {self.model_name}. Supported block sizes are {self.encoder_reduction}, {self.encoder_reduction // 2}, and {self.encoder_reduction // 4}."
                self.refiner = upsample_block(
                    in_channels=self.encoder_channels,
                    out_channels=refiner_out_channels[self.model_name],
                    norm_layer=norm_layer,
                    activation=activation,
                )

            self.refiner_channels = refiner_out_channels[self.model_name]

        else:
            # Strides already match: no refinement needed.
            self.refiner = nn.Identity()
            self.refiner_channels = self.encoder_channels

        # The refiner changes stride to block_size by construction.
        self.refiner_reduction = self.block_size

        # Decoder: repeatedly halve channels until the width is <= 256
        # (identity / 1 / 2 / 3 stages depending on the refiner width).
        if self.refiner_channels <= 256:
            self.decoder = nn.Identity()
            self.decoder_channels = self.refiner_channels
        elif self.refiner_channels <= 512:
            self.decoder = decoder_block(
                in_channels=self.refiner_channels,
                out_channels=self.refiner_channels // 2,
                norm_layer=norm_layer,
                activation=activation,
            )
            self.decoder_channels = self.refiner_channels // 2
        elif self.refiner_channels <= 1024:
            self.decoder = nn.Sequential(
                decoder_block(
                    in_channels=self.refiner_channels,
                    out_channels=self.refiner_channels // 2,
                    norm_layer=norm_layer,
                    activation=activation,
                ),
                decoder_block(
                    in_channels=self.refiner_channels // 2,
                    out_channels=self.refiner_channels // 4,
                    norm_layer=norm_layer,
                    activation=activation,
                ),
            )
            self.decoder_channels = self.refiner_channels // 4
        else:
            self.decoder = nn.Sequential(
                decoder_block(
                    in_channels=self.refiner_channels,
                    out_channels=self.refiner_channels // 2,
                    norm_layer=norm_layer,
                    activation=activation,
                ),
                decoder_block(
                    in_channels=self.refiner_channels // 2,
                    out_channels=self.refiner_channels // 4,
                    norm_layer=norm_layer,
                    activation=activation,
                ),
                decoder_block(
                    in_channels=self.refiner_channels // 4,
                    out_channels=self.refiner_channels // 8,
                    norm_layer=norm_layer,
                    activation=activation,
                ),
            )
            self.decoder_channels = self.refiner_channels // 8

        # The decoder only changes channel count, never spatial stride.
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the backbone and unwrap the single feature map from the list."""
        return self.encoder(x)[0]

    def refine(self, x: Tensor) -> Tensor:
        """Re-scale encoder features to the target block_size stride."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Reduce channel count (identity when already <= 256)."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Full pipeline: encode -> refine -> decode."""
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x
|
|
|
|
def _timm_model(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> TIMMModel:
    """Factory: build a :class:`TIMMModel` for *model_name*.

    All arguments are forwarded unchanged to the TIMMModel constructor.
    """
    return TIMMModel(
        model_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )
|
|