"""Model registry: one entry point to build any in-framework segmenter.

    build_model("unet", in_channels=3, num_classes=2, ...)

Architectures:
  * SMP zoo: unet, unetpp, manet, linknet, fpn, pspnet, deeplabv3, deeplabv3plus, pan
  * attention_unet : reimplemented Oktay attention U-Net
  * transunet      : sota/TransUNet model def (img_size divisible by 16; 224 canonical)
  * swinunet       : sota/Swin-Unet model def (img_size must be 224)

nnU-Net, U-Mamba (separate CLIs) and SAM-family (dropped) are intentionally NOT here.
"""
| from __future__ import annotations |
|
|
| import torch.nn as nn |
|
|
| from .smp_models import is_smp_arch, build_smp |
| from .attention_unet import build_attention_unet |
| from .transunet_wrap import build_transunet |
| from .swinunet_wrap import build_swinunet |
|
|
# Architectures that must be built at one specific input size, mapped to it.
_FIXED_INPUT = {"swinunet": 224}


def required_img_size(arch: str) -> int | None:
    """Return the fixed input size an architecture requires, if any.

    Args:
        arch: Architecture name; matched case-insensitively.

    Returns:
        The image size registered for ``arch`` in ``_FIXED_INPUT``
        (currently only ``swinunet`` -> 224), or ``None`` when the
        architecture has no fixed-size entry.
    """
    return _FIXED_INPUT.get(arch.lower())
|
|
|
|
def build_model(arch: str, in_channels: int, num_classes: int, img_size: int = 256,
                encoder: str = "resnet34", encoder_weights: str = "imagenet",
                pretrained_ckpt: str = "") -> nn.Module:
    """Build a segmentation model selected by architecture name.

    Args:
        arch: Architecture name; matched case-insensitively.
        in_channels: Forwarded to every builder.
        num_classes: Forwarded to every builder.
        img_size: Forwarded only to the ``transunet`` and ``swinunet``
            builders; the other builders do not receive it.
        encoder: Forwarded to the SMP builder and to the ``transunet`` builder.
        encoder_weights: Forwarded only to the SMP builder.
        pretrained_ckpt: Forwarded only to the ``transunet`` and ``swinunet``
            builders.

    Returns:
        Whatever the selected builder returns (annotated as ``nn.Module``).

    Raises:
        ValueError: If ``arch`` matches none of the registered architectures.
    """
    # NOTE(review): img_size is passed through unchecked. The default (256)
    # differs from the 224 that the module docstring says swinunet needs, so
    # callers presumably consult required_img_size() first -- confirm.
    a = arch.lower()

    # SMP architectures are checked first, then the hand-wired models.
    if is_smp_arch(a):
        return build_smp(a, in_channels, num_classes,
                         encoder=encoder, encoder_weights=encoder_weights)
    if a == "attention_unet":
        return build_attention_unet(in_channels, num_classes)
    if a == "transunet":
        return build_transunet(in_channels, num_classes, img_size=img_size,
                               encoder=encoder, pretrained_ckpt=pretrained_ckpt)
    if a == "swinunet":
        return build_swinunet(in_channels, num_classes, img_size=img_size,
                              pretrained_ckpt=pretrained_ckpt)

    # Report the name exactly as the caller spelled it, not the lowered form.
    raise ValueError(
        f"unknown arch '{arch}'. "
        "SMP: unet/unetpp/manet/linknet/fpn/pspnet/deeplabv3/deeplabv3plus/pan; "
        "plus attention_unet/transunet/swinunet."
    )
|
|