MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
1.97 kB
"""Model registry: one entry point to build any in-framework segmenter.

Usage:
    build_model("unet", in_channels=3, num_classes=2, ...)

Architectures:
  * SMP zoo        : unet, unetpp, manet, linknet, fpn, pspnet, deeplabv3, deeplabv3plus, pan
  * attention_unet : reimplemented Oktay attention U-Net
  * transunet      : sota/TransUNet model def (img_size divisible by 16; 224 canonical)
  * swinunet       : sota/Swin-Unet model def (img_size must be 224)

nnU-Net, U-Mamba (separate CLIs) and SAM-family (dropped) are intentionally NOT here.
"""
from __future__ import annotations
import torch.nn as nn
from .smp_models import is_smp_arch, build_smp
from .attention_unet import build_attention_unet
from .transunet_wrap import build_transunet
from .swinunet_wrap import build_swinunet
# Archs that demand a specific img_size, keyed by lower-case arch name.
_FIXED_INPUT = {"swinunet": 224}


def required_img_size(arch: str) -> int | None:
    """Return the img_size that ``arch`` demands, if it demands one.

    Args:
        arch: Architecture name; matched case-insensitively against
            ``_FIXED_INPUT``.

    Returns:
        The required image size for archs listed in ``_FIXED_INPUT``,
        otherwise ``None``.
    """
    return _FIXED_INPUT.get(arch.lower())
def build_model(arch: str, in_channels: int, num_classes: int, img_size: int = 256,
                encoder: str = "resnet34", encoder_weights: str = "imagenet",
                pretrained_ckpt: str = "") -> nn.Module:
    """Build the segmentation network registered under ``arch``.

    Args:
        arch: Architecture name; matched case-insensitively.
        in_channels: Forwarded to the selected builder.
        num_classes: Forwarded to the selected builder.
        img_size: Forwarded to the transunet and swinunet builders only.
            For an arch with a fixed input size (see ``required_img_size``)
            it must equal that size.
        encoder: Forwarded to the SMP and transunet builders.
        encoder_weights: Forwarded to the SMP builder only.
        pretrained_ckpt: Forwarded to the transunet and swinunet builders.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``arch`` is not a known architecture, or if
            ``img_size`` conflicts with the size the arch requires.
    """
    a = arch.lower()

    # Fail fast with a clear message instead of forwarding a size the
    # architecture is declared not to support (the default 256 would
    # otherwise be handed to swinunet, which requires 224).
    fixed_size = required_img_size(a)
    if fixed_size is not None and img_size != fixed_size:
        raise ValueError(
            f"arch '{arch}' requires img_size={fixed_size}, got {img_size}. "
            "Use required_img_size(arch) to pick a compatible size."
        )

    if is_smp_arch(a):
        return build_smp(a, in_channels, num_classes,
                         encoder=encoder, encoder_weights=encoder_weights)
    if a == "attention_unet":
        return build_attention_unet(in_channels, num_classes)
    if a == "transunet":
        return build_transunet(in_channels, num_classes, img_size=img_size,
                               encoder=encoder, pretrained_ckpt=pretrained_ckpt)
    if a == "swinunet":
        return build_swinunet(in_channels, num_classes, img_size=img_size,
                              pretrained_ckpt=pretrained_ckpt)
    raise ValueError(f"unknown arch '{arch}'. "
                     "SMP: unet/unetpp/manet/linknet/fpn/pspnet/deeplabv3/deeplabv3plus/pan; "
                     "plus attention_unet/transunet/swinunet.")