"""segmentation_models.pytorch (SMP) backbones.

SMP gives us a large architecture x encoder zoo behind one constructor, and
the models it builds are plain ``nn.Module`` objects (DDP- and bf16-friendly).
Grayscale input (``in_channels=1``) with ImageNet encoder weights is handled
by SMP's built-in channel adaptation.
"""

from __future__ import annotations

import segmentation_models_pytorch as smp

# Our arch name (matched case-insensitively) -> SMP architecture key passed to
# ``smp.create_model``. Several aliases may map to the same SMP architecture.
_SMP_ARCH = {
    "unet": "Unet",
    "unetpp": "UnetPlusPlus",
    "unetplusplus": "UnetPlusPlus",
    "manet": "MAnet",
    "linknet": "Linknet",
    "fpn": "FPN",
    "pspnet": "PSPNet",
    "deeplabv3": "DeepLabV3",
    "deeplabv3plus": "DeepLabV3Plus",
    "pan": "PAN",
}

# String spellings of "no pretrained weights" for ``encoder_weights``. They are
# compared after strip() + lower(), because config files and CLI flags deliver
# strings such as "None" rather than the Python ``None`` object.
_NO_WEIGHTS = frozenset({"", "none"})


def is_smp_arch(arch: str) -> bool:
    """Return True if *arch* (case-insensitive) is an architecture ``build_smp`` supports."""
    return arch.lower() in _SMP_ARCH


def _normalize_weights(encoder_weights: str | None) -> str | None:
    """Map ``None`` and the "no weights" string spellings to ``None``; pass others through."""
    if encoder_weights is None:
        return None
    if isinstance(encoder_weights, str) and encoder_weights.strip().lower() in _NO_WEIGHTS:
        return None
    return encoder_weights


def build_smp(arch: str, in_channels: int, num_classes: int,
              encoder: str = "resnet34", encoder_weights: str | None = "imagenet",
              **_):
    """Build a segmentation model through ``smp.create_model``.

    Args:
        arch: Architecture name, case-insensitive; must be a key of ``_SMP_ARCH``
            (e.g. "unet", "unetpp", "deeplabv3plus").
        in_channels: Number of input image channels.
        num_classes: Number of output channels, forwarded as SMP's ``classes``.
        encoder: SMP encoder name (e.g. "resnet34").
        encoder_weights: Pretrained-weights identifier forwarded to SMP
            (e.g. "imagenet"). ``None``, ``""`` or "none" in any letter case
            (surrounding whitespace ignored) means no pretrained weights.
        **_: Extra keyword arguments are accepted and ignored. Presumably this
            lets callers pass a shared config dict; confirm against callers.

    Returns:
        The model object returned by ``smp.create_model``.

    Raises:
        KeyError: If *arch* is not a supported SMP architecture.
    """
    try:
        smp_arch = _SMP_ARCH[arch.lower()]
    except KeyError:
        # Keep KeyError (what callers already see) but make the message actionable.
        raise KeyError(
            f"Unknown SMP architecture {arch!r}; expected one of {sorted(_SMP_ARCH)}"
        ) from None
    return smp.create_model(
        arch=smp_arch,
        encoder_name=encoder,
        encoder_weights=_normalize_weights(encoder_weights),
        in_channels=in_channels,
        classes=num_classes,
    )