| | |
| | try: |
| | import timm |
| | except ImportError: |
| | timm = None |
| |
|
| | from mmengine.model import BaseModule |
| | from mmengine.registry import MODELS as MMENGINE_MODELS |
| |
|
| | from mmseg.registry import MODELS |
| |
|
| |
|
@MODELS.register_module()
class TIMMBackbone(BaseModule):
    """Wrapper to use backbones from timm library. More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_ .

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Forwarded to ``timm.create_model``. When True,
            timm builds a feature-extraction model that returns intermediate
            feature maps instead of classification logits. Default: True.
        pretrained (bool): Load pretrained weights if True. Default: True.
        checkpoint_path (str): Path of checkpoint to load after
            model is initialized. Default: ''.
        in_channels (int): Number of input image channels. Default: 3.
        init_cfg (dict, optional): Initialization config dict. Default: None.
        **kwargs: Other timm & model specific arguments, forwarded to
            ``timm.create_model``. A ``norm_layer`` given as a string is
            resolved through the MMEngine ``MODELS`` registry first; any
            non-string value is forwarded unchanged.

    Raises:
        RuntimeError: If timm is not installed.
        KeyError: If ``norm_layer`` is a string that is not registered in the
            MMEngine ``MODELS`` registry.
    """

    def __init__(
        self,
        model_name,
        features_only=True,
        pretrained=True,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        if timm is None:
            raise RuntimeError('timm is not installed')
        super().__init__(init_cfg)
        if 'norm_layer' in kwargs:
            # `kwargs` is a fresh dict local to this call, so rewriting the
            # entry does not affect the caller's objects.
            kwargs['norm_layer'] = self._resolve_norm_layer(
                kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs,
        )

        # Null out the classification-head attributes; the backbone is used
        # for its features, not for class predictions.
        # NOTE(review): with features_only=False the timm model's own forward
        # may still call these heads — confirm this wrapper is only used with
        # features_only=True.
        self.timm_model.global_pool = None
        self.timm_model.fc = None
        self.timm_model.classifier = None

        # timm has already loaded weights in this case; mark the module as
        # initialized so BaseModule.init_weights() does not re-initialize
        # (and overwrite) them.
        if pretrained or checkpoint_path:
            self._is_init = True

    @staticmethod
    def _resolve_norm_layer(norm_layer):
        """Turn a registry name into a norm-layer class; pass other values through.

        Args:
            norm_layer: Either a name registered in the MMEngine ``MODELS``
                registry, or an already-resolved class/callable (or None).

        Returns:
            The registered class for a string name, otherwise ``norm_layer``
            unchanged.

        Raises:
            KeyError: If ``norm_layer`` is a string with no registry entry.
        """
        if not isinstance(norm_layer, str):
            # Already a class/callable: let timm consume it directly.
            return norm_layer
        norm_cls = MMENGINE_MODELS.get(norm_layer)
        if norm_cls is None:
            # Registry.get returns None for unknown keys; failing loudly here
            # avoids timm silently falling back to its default norm layer.
            raise KeyError(
                f'norm_layer {norm_layer!r} is not registered in the '
                'MMEngine MODELS registry')
        return norm_cls

    def forward(self, x):
        """Run the wrapped timm model on ``x`` and return its output as-is.

        Args:
            x: Input batch. Presumably an image tensor with ``in_channels``
                channels — confirm against the caller.

        Returns:
            Whatever ``self.timm_model(x)`` returns. With
            ``features_only=True`` this is expected to be a list of feature
            maps, one per extracted stage.
        """
        features = self.timm_model(x)
        return features
| |
|