"""segmentation_models.pytorch (SMP) backbones.

SMP gives us a large architecture x encoder zoo behind one constructor, and the
models it builds are plain nn.Modules (DDP- and bf16-friendly). Grayscale input
(in_channels=1) with ImageNet encoder weights is handled by SMP's built-in
channel adaptation.
"""
| from __future__ import annotations |
|
|
| import segmentation_models_pytorch as smp |
|
|
| |
# Lower-case architecture key accepted by this module -> SMP model class name
# that is handed to ``smp.create_model(arch=...)``. Several keys may alias the
# same class (e.g. "unetpp" and "unetplusplus").
_SMP_ARCH = dict(
    unet="Unet",
    unetpp="UnetPlusPlus",
    unetplusplus="UnetPlusPlus",
    manet="MAnet",
    linknet="Linknet",
    fpn="FPN",
    pspnet="PSPNet",
    deeplabv3="DeepLabV3",
    deeplabv3plus="DeepLabV3Plus",
    pan="PAN",
)
|
|
|
|
def is_smp_arch(arch: str) -> bool:
    """Report whether *arch* names an architecture this module can build.

    The lookup is case-insensitive: *arch* is lower-cased and checked against
    the keys of ``_SMP_ARCH``.
    """
    normalized = arch.lower()
    return normalized in _SMP_ARCH.keys()
|
|
|
|
def build_smp(arch: str, in_channels: int, num_classes: int,
              encoder: str = "resnet34", encoder_weights: str | None = "imagenet",
              **_):
    """Build an SMP segmentation model for ``arch`` via ``smp.create_model``.

    Args:
        arch: Architecture key, matched case-insensitively against ``_SMP_ARCH``
            (e.g. ``"unet"``, ``"unetpp"``, ``"fpn"``, ``"deeplabv3plus"``).
        in_channels: Number of input channels; 1 (grayscale) works with ImageNet
            encoder weights through SMP's built-in channel adaptation.
        num_classes: Number of output classes, forwarded to SMP as ``classes``.
        encoder: SMP encoder name, forwarded as ``encoder_name``.
        encoder_weights: Pretrained-weights name for the encoder. ``None``, ``""``
            and ``"none"`` in any letter case (surrounding whitespace ignored)
            mean "no pretrained weights"; any other value is forwarded unchanged.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        The model built by ``smp.create_model``.

    Raises:
        KeyError: If ``arch`` is not one of the supported architecture keys; the
            message lists the supported keys.
    """
    try:
        smp_arch = _SMP_ARCH[arch.lower()]
    except KeyError:
        supported = ", ".join(sorted(_SMP_ARCH))
        # Keep KeyError (callers may catch it) but say what went wrong.
        raise KeyError(
            f"unknown SMP architecture {arch!r}; expected one of: {supported}"
        ) from None

    # Accept every spelling of "no pretrained weights". A value such as "None" or
    # "NONE" (e.g. read as a string from a YAML/CLI config) would otherwise be
    # forwarded to SMP unchanged as if it were a weights name.
    no_weights = encoder_weights is None or (
        isinstance(encoder_weights, str)
        and encoder_weights.strip().lower() in ("", "none")
    )
    weights = None if no_weights else encoder_weights

    return smp.create_model(
        arch=smp_arch,
        encoder_name=encoder,
        encoder_weights=weights,
        in_channels=in_channels,
        classes=num_classes,
    )
|
|