MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
1.18 kB
"""segmentation_models.pytorch (SMP) backbones.
SMP gives us a large architecture x encoder zoo behind one constructor and is a
plain nn.Module (DDP- and bf16-friendly). Grayscale (in_channels=1) with ImageNet
encoder weights is handled by SMP's built-in channel adaptation.
"""
from __future__ import annotations
import segmentation_models_pytorch as smp
# Lookup table: architecture name used by this project (lowercase) -> SMP
# architecture key. Callers lowercase the requested name before looking it up.
_SMP_ARCH = {
    "unet": "Unet",
    # Two spellings are accepted for the same SMP architecture.
    "unetpp": "UnetPlusPlus",
    "unetplusplus": "UnetPlusPlus",
    "manet": "MAnet",
    "linknet": "Linknet",
    "fpn": "FPN",
    "pspnet": "PSPNet",
    "deeplabv3": "DeepLabV3",
    "deeplabv3plus": "DeepLabV3Plus",
    "pan": "PAN",
}
def is_smp_arch(arch: str) -> bool:
    """Return whether *arch* is one of the architecture names in ``_SMP_ARCH``.

    The comparison is case-insensitive: *arch* is lowercased before the
    membership test against the table's keys.
    """
    normalized = arch.lower()
    return normalized in _SMP_ARCH
def build_smp(arch: str, in_channels: int, num_classes: int,
              encoder: str = "resnet34",
              encoder_weights: str | None = "imagenet", **_):
    """Build a segmentation_models.pytorch model from one of our arch names.

    Args:
        arch: Architecture name; matched case-insensitively against the keys
            of ``_SMP_ARCH``.
        in_channels: Number of input channels, forwarded to SMP unchanged.
        num_classes: Number of output classes, forwarded to SMP as ``classes``.
        encoder: Encoder name, forwarded to SMP as ``encoder_name``.
        encoder_weights: Identifier of the pretrained encoder weights. ``None``,
            a blank string, or ``"none"`` in any letter case (e.g. the string
            ``"None"`` that a YAML config produces) means "no pretrained
            weights". Any other value is forwarded to SMP unchanged.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        The model returned by ``smp.create_model``.

    Raises:
        KeyError: If *arch* is not a supported architecture name.
    """
    try:
        smp_arch = _SMP_ARCH[arch.lower()]
    except KeyError:
        # Same exception type as the plain dict lookup, but say what is valid.
        raise KeyError(
            f"unknown SMP architecture {arch!r}; "
            f"supported: {', '.join(sorted(_SMP_ARCH))}"
        ) from None

    # Normalise the "no weights" sentinels. Only the sentinel check is
    # case/whitespace-insensitive; real weight names are passed through as-is.
    if encoder_weights is None:
        weights = None
    elif isinstance(encoder_weights, str) and encoder_weights.strip().lower() in ("", "none"):
        weights = None
    else:
        weights = encoder_weights

    return smp.create_model(
        arch=smp_arch,
        encoder_name=encoder,
        encoder_weights=weights,
        in_channels=in_channels,
        classes=num_classes,
    )