# CFPVesselSeg/models/__init__.py
# (page metadata: farrell236 - "add src" - e99a83c)
from .unet import build_resunet
from .deeplabv3 import build_deeplabv3
from .vit import build_vit
def build_model(
    model_name="resunet",
    num_classes=1,
    in_channels=3,
    image_size=512,
    backbone="resnet50",
    pretrained=True,
    base_channels=32,
    dropout=0.0,
):
    """Build one of the supported models by name.

    The name is lower-cased before lookup, so matching is case-insensitive.
    Each architecture receives only the subset of arguments listed below;
    the remaining arguments are ignored for that architecture.

    Args:
        model_name: One of ``"resunet"``, ``"deeplabv3"`` or ``"vit"``.
        num_classes: Forwarded to every builder.
        in_channels: Forwarded to the resunet builder (``in_channels``) and
            the vit builder (``in_chans``). Not forwarded for deeplabv3.
        image_size: Forwarded to the vit builder only (``img_size``).
        backbone: For deeplabv3, the backbone name (resnet50, resnet101).
            For vit, the variant (tiny, small, base, large, or a timm model
            name). Unused for resunet.
        pretrained: Forwarded to deeplabv3 as ``pretrained_backbone`` and to
            vit as ``pretrained``. Unused for resunet.
        base_channels: Forwarded to the resunet builder only.
        dropout: Forwarded to the resunet and vit builders.

    Returns:
        Whatever the selected builder function returns.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    model_name = model_name.lower()

    # Each entry is a zero-argument factory so that only the builder that is
    # actually selected gets called.
    factories = {
        "resunet": lambda: build_resunet(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            dropout=dropout,
        ),
        "deeplabv3": lambda: build_deeplabv3(
            backbone=backbone,
            num_classes=num_classes,
            pretrained_backbone=pretrained,
        ),
        "vit": lambda: build_vit(
            variant=backbone,
            num_classes=num_classes,
            pretrained=pretrained,
            in_chans=in_channels,
            img_size=image_size,
            dropout=dropout,
        ),
    }

    make = factories.get(model_name)
    if make is None:
        raise ValueError(
            f"Unsupported model_name: {model_name}. "
            "Choose from: resunet, deeplabv3, vit."
        )
    return make()