from .unet import build_resunet
from .deeplabv3 import build_deeplabv3
from .vit import build_vit


def build_model(
    model_name="resunet",
    num_classes=1,
    in_channels=3,
    image_size=512,
    backbone="resnet50",
    pretrained=True,
    base_channels=32,
    dropout=0.0,
):
    """Build and return a model selected by name.

    The name is matched case-insensitively against the supported
    architectures. The builder's result is returned unchanged.

    Args:
        model_name: One of "resunet", "deeplabv3" or "vit" (any case).
            Must be a string, because ``.lower()`` is called on it.
        num_classes: Forwarded to every builder as ``num_classes``.
        in_channels: Forwarded to resunet (``in_channels``) and vit
            (``in_chans``). deeplabv3 does not receive it.
        image_size: Forwarded to vit only (``img_size``).
        backbone: For deeplabv3, the backbone name, e.g. "resnet50" or
            "resnet101". For vit, passed as ``variant``: "tiny", "small",
            "base", "large" or a timm model name. Ignored by resunet.
            NOTE(review): the default "resnet50" is also passed to vit
            when ``backbone`` is left unset. Callers building a vit
            presumably want to set it explicitly; confirm intent.
        pretrained: Forwarded to deeplabv3 (``pretrained_backbone``) and
            vit (``pretrained``). Ignored by resunet.
        base_channels: Forwarded to resunet only.
        dropout: Forwarded to resunet and vit. deeplabv3 does not
            receive it.

    Returns:
        Whatever the selected builder (``build_resunet``,
        ``build_deeplabv3`` or ``build_vit``) returns.

    Raises:
        ValueError: If ``model_name`` (lowercased) is not supported.
    """
    key = model_name.lower()

    # Each entry defers construction until it is chosen, so only the
    # requested architecture's builder is ever invoked.
    builders = {
        "resunet": lambda: build_resunet(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            dropout=dropout,
        ),
        "deeplabv3": lambda: build_deeplabv3(
            backbone=backbone,
            num_classes=num_classes,
            pretrained_backbone=pretrained,
        ),
        "vit": lambda: build_vit(
            variant=backbone,
            num_classes=num_classes,
            pretrained=pretrained,
            in_chans=in_channels,
            img_size=image_size,
            dropout=dropout,
        ),
    }

    factory = builders.get(key)
    if factory is None:
        raise ValueError(
            f"Unsupported model_name: {key}. "
            "Choose from: resunet, deeplabv3, vit."
        )
    return factory()