|
|
|
|
|
import logging |
|
|
from abc import ABCMeta, abstractmethod |
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
from mmcv_custom.checkpoint import load_checkpoint |
|
|
|
|
|
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    """

    def init_weights(self, pretrained=None, patch_padding='pad', part_features=None):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
            patch_padding (str): Forwarded verbatim to ``load_checkpoint``;
                presumably controls how patch-embedding weights with
                mismatched shapes are adapted — confirm in
                ``mmcv_custom.checkpoint``.
            part_features (int | None): Forwarded verbatim to
                ``load_checkpoint``; semantics defined there.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # Use a module-named logger instead of the root logger so log
            # records emitted during checkpoint loading are attributable
            # to this module.
            logger = logging.getLogger(__name__)
            load_checkpoint(
                self,
                pretrained,
                strict=False,
                logger=logger,
                patch_padding=patch_padding,
                part_features=part_features)
        elif pretrained is None:
            # Default/custom initialization is left to subclasses.
            pass
        else:
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """
|
|
|