Spaces:
Running
Running
| from .unet import build_resunet | |
| from .deeplabv3 import build_deeplabv3 | |
| from .vit import build_vit | |
def build_model(
    model_name="resunet",
    num_classes=1,
    in_channels=3,
    image_size=512,
    backbone="resnet50",
    pretrained=True,
    base_channels=32,
    dropout=0.0,
):
    """Build and return a model selected by name.

    Dispatches to ``build_resunet``, ``build_deeplabv3`` or ``build_vit``.
    Not every builder accepts every argument; unused arguments are noted
    per model below.

    Args:
        model_name: One of ``"resunet"``, ``"deeplabv3"``, ``"vit"``.
            Matching ignores case and surrounding whitespace.
        num_classes: Number of output classes, forwarded to every builder.
        in_channels: Number of input channels. Forwarded to resunet (as
            ``in_channels``) and vit (as ``in_chans``). NOT forwarded to
            deeplabv3; a non-default value there triggers a warning.
        image_size: Input image size, forwarded only to vit (``img_size``).
        backbone: Model-specific backbone/variant name.
            - deeplabv3: ``resnet50`` or ``resnet101`` (``backbone``).
            - vit: ``tiny``, ``small``, ``base``, ``large`` or a timm model
              name (``variant``). NOTE: the default ``"resnet50"`` is listed
              as a deeplabv3 backbone; pass a ViT variant explicitly when
              ``model_name="vit"``.
            - resunet: unused.
        pretrained: Load pretrained weights. Forwarded to deeplabv3 (as
            ``pretrained_backbone``) and vit; unused by resunet.
        base_channels: Width of the first stage, forwarded only to resunet.
        dropout: Dropout rate, forwarded to resunet and vit; unused by
            deeplabv3.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``model_name`` does not match a supported model.
    """
    import warnings

    # Normalise so config/CLI values such as " ResUNet\n" still match.
    name = model_name.strip().lower()

    if name == "resunet":
        return build_resunet(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            dropout=dropout,
        )

    if name == "deeplabv3":
        # build_deeplabv3 takes no channel argument, so a non-default
        # in_channels would otherwise be dropped without any feedback.
        if in_channels != 3:
            warnings.warn(
                f"in_channels={in_channels} is not passed to "
                "build_deeplabv3 and will be ignored.",
                stacklevel=2,
            )
        return build_deeplabv3(
            backbone=backbone,
            num_classes=num_classes,
            pretrained_backbone=pretrained,
        )

    if name == "vit":
        return build_vit(
            variant=backbone,
            num_classes=num_classes,
            pretrained=pretrained,
            in_chans=in_channels,
            img_size=image_size,
            dropout=dropout,
        )

    raise ValueError(
        f"Unsupported model_name: {name}. "
        "Choose from: resunet, deeplabv3, vit."
    )