Spaces:
Running
Running
File size: 1,434 Bytes
e99a83c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | from .unet import build_resunet
from .deeplabv3 import build_deeplabv3
from .vit import build_vit
def build_model(
    model_name="resunet",
    num_classes=1,
    in_channels=3,
    image_size=512,
    backbone="resnet50",
    pretrained=True,
    base_channels=32,
    dropout=0.0,
):
    """Build a model by architecture name and forward the relevant options.

    ``model_name`` is matched case-insensitively (it is lower-cased before
    lookup). Only the arguments each builder accepts are forwarded; the rest
    are ignored for that architecture.

    Args:
        model_name: One of ``"resunet"``, ``"deeplabv3"`` or ``"vit"``.
        num_classes: Forwarded to every builder.
        in_channels: Forwarded to resunet (``in_channels``) and vit
            (``in_chans``). Not forwarded to deeplabv3.
        image_size: Forwarded only to vit (``img_size``).
        backbone: For deeplabv3, the backbone name (documented options:
            ``resnet50``, ``resnet101``). For vit, the ``variant`` (documented
            options: ``tiny``, ``small``, ``base``, ``large`` or a timm model
            name). Unused for resunet.
        pretrained: Forwarded to deeplabv3 (``pretrained_backbone``) and vit
            (``pretrained``). Unused for resunet.
        base_channels: Forwarded only to resunet.
        dropout: Forwarded to resunet and vit. Not forwarded to deeplabv3.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If the lower-cased ``model_name`` is not supported.
    """
    # NOTE(review): the default backbone "resnet50" is passed as the ViT
    # ``variant`` when model_name="vit" and backbone is not overridden —
    # confirm build_vit handles that as intended.
    key = model_name.lower()

    # Each entry is a zero-argument thunk so that only the requested builder
    # is looked up and invoked; unsupported names never touch any builder.
    factories = {
        "resunet": lambda: build_resunet(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            dropout=dropout,
        ),
        # NOTE(review): in_channels and dropout are not forwarded here —
        # confirm build_deeplabv3 is meant to ignore them.
        "deeplabv3": lambda: build_deeplabv3(
            backbone=backbone,
            num_classes=num_classes,
            pretrained_backbone=pretrained,
        ),
        "vit": lambda: build_vit(
            variant=backbone,
            num_classes=num_classes,
            pretrained=pretrained,
            in_chans=in_channels,
            img_size=image_size,
            dropout=dropout,
        ),
    }

    factory = factories.get(key)
    if factory is None:
        raise ValueError(
            f"Unsupported model_name: {key}. "
            "Choose from: resunet, deeplabv3, vit."
        )
    return factory()