File size: 1,434 Bytes
e99a83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from .unet import build_resunet
from .deeplabv3 import build_deeplabv3
from .vit import build_vit


def build_model(
    model_name="resunet",
    num_classes=1,
    in_channels=3,
    image_size=512,
    backbone="resnet50",
    pretrained=True,
    base_channels=32,
    dropout=0.0,
):
    """
    Generic model builder: dispatch to the builder selected by ``model_name``.

    model_name options (case-insensitive, surrounding whitespace ignored):
        resunet
        deeplabv3
        vit

    backbone:
        For deeplabv3 (passed as ``backbone``):
            resnet50, resnet101

        For vit (passed as ``variant``):
            tiny, small, base, large
            or a timm model name

        For resunet:
            unused

    Which arguments reach which builder:
        resunet   -> in_channels, num_classes, base_channels, dropout
        deeplabv3 -> backbone, num_classes, pretrained (as pretrained_backbone)
                     NOTE: in_channels, image_size and dropout are NOT
                     forwarded to build_deeplabv3 and are silently ignored.
        vit       -> backbone (as variant), num_classes, pretrained,
                     in_channels (as in_chans), image_size (as img_size),
                     dropout

    Returns:
        Whatever the selected builder (build_resunet / build_deeplabv3 /
        build_vit) returns.

    Raises:
        ValueError: if ``model_name`` is not a string or does not name one of
            the supported models.
    """
    supported = "Choose from: resunet, deeplabv3, vit."

    # Reject non-strings (e.g. None from a missing config key) with the same
    # ValueError used for unknown names, instead of an AttributeError from
    # calling .lower() on them.
    if not isinstance(model_name, str):
        raise ValueError(f"Unsupported model_name: {model_name!r}. {supported}")

    # Tolerate stray whitespace/case coming from CLI flags or config files.
    model_name = model_name.strip().lower()

    if model_name == "resunet":
        return build_resunet(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            dropout=dropout,
        )

    if model_name == "deeplabv3":
        # NOTE(review): in_channels is not passed here; confirm that
        # build_deeplabv3 only supports its default input channel count.
        return build_deeplabv3(
            backbone=backbone,
            num_classes=num_classes,
            pretrained_backbone=pretrained,
        )

    if model_name == "vit":
        return build_vit(
            variant=backbone,
            num_classes=num_classes,
            pretrained=pretrained,
            in_chans=in_channels,
            img_size=image_size,
            dropout=dropout,
        )

    raise ValueError(f"Unsupported model_name: {model_name}. {supported}")