"""Model registry: one entry point to build any in-framework segmenter. build_model("unet", in_channels=3, num_classes=2, ...) Architectures: * SMP zoo: unet, unetpp, manet, linknet, fpn, pspnet, deeplabv3, deeplabv3plus, pan * attention_unet : reimplemented Oktay attention U-Net * transunet : sota/TransUNet model def (img_size divisible by 16; 224 canonical) * swinunet : sota/Swin-Unet model def (img_size must be 224) nnU-Net, U-Mamba (separate CLIs) and SAM-family (dropped) are intentionally NOT here. """ from __future__ import annotations import torch.nn as nn from .smp_models import is_smp_arch, build_smp from .attention_unet import build_attention_unet from .transunet_wrap import build_transunet from .swinunet_wrap import build_swinunet _FIXED_INPUT = {"swinunet": 224} # archs that demand a specific img_size def required_img_size(arch: str): return _FIXED_INPUT.get(arch.lower()) def build_model(arch: str, in_channels: int, num_classes: int, img_size: int = 256, encoder: str = "resnet34", encoder_weights: str = "imagenet", pretrained_ckpt: str = "") -> nn.Module: a = arch.lower() if is_smp_arch(a): return build_smp(a, in_channels, num_classes, encoder=encoder, encoder_weights=encoder_weights) if a == "attention_unet": return build_attention_unet(in_channels, num_classes) if a == "transunet": return build_transunet(in_channels, num_classes, img_size=img_size, encoder=encoder, pretrained_ckpt=pretrained_ckpt) if a == "swinunet": return build_swinunet(in_channels, num_classes, img_size=img_size, pretrained_ckpt=pretrained_ckpt) raise ValueError(f"unknown arch '{arch}'. " f"SMP: unet/unetpp/manet/linknet/fpn/pspnet/deeplabv3/deeplabv3plus/pan; " f"plus attention_unet/transunet/swinunet.")