Spaces:
Sleeping
Sleeping
| """ | |
| Baseline Model Factory | |
| This module provides a factory function to create baseline segmentation models | |
| using the segmentation-models-pytorch library. | |
| Supported architectures: | |
| - Unet | |
| - UnetPlusPlus | |
| - PSPNet | |
| - PAN | |
| - MAnet | |
| - Linknet | |
| - FPN | |
| - DeepLabV3Plus | |
| - DeepLabV3 | |
| """ | |
| import torch | |
| import segmentation_models_pytorch as smp | |
def create_model(architecture='Unet', encoder_name='resnet50', in_channels=3,
                 num_classes=1, encoder_weights='imagenet', device=None):
    """
    Create a baseline segmentation model using segmentation-models-pytorch.

    Args:
        architecture (str): Name of a model class exported by
            ``segmentation_models_pytorch``. Documented choices:
            'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
            'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3'
        encoder_name (str): Encoder backbone name (default: 'resnet50')
        in_channels (int): Number of input channels (default: 3)
        num_classes (int): Number of output classes; forwarded to the smp
            model as ``classes`` (default: 1)
        encoder_weights (str | None): Pretrained weights source
            (default: 'imagenet'). smp also accepts None for randomly
            initialized encoder weights.
        device (str | torch.device | None): Device to move the model to.
            When None (default), CUDA is used if available, otherwise CPU.

    Returns:
        nn.Module: Segmentation model placed on the selected device.

    Raises:
        ValueError: If ``architecture`` does not name an ``nn.Module``
            subclass exported by ``segmentation_models_pytorch``.

    Example:
        >>> model = create_model('DeepLabV3Plus', encoder_name='resnet50').eval()
        >>> x = torch.randn(1, 3, 448, 448, device=next(model.parameters()).device)
        >>> with torch.no_grad():
        ...     output = model(x)
        >>> output.shape
        torch.Size([1, 1, 448, 448])
    """
    # NOTE: .eval() in the example matters because a batch of one sample can
    # fail in BatchNorm layers while in training mode. The input must also
    # live on the same device as the model.
    model_cls = getattr(smp, architecture, None)
    # A bare getattr would also "succeed" for non-model attributes of smp
    # (e.g. submodules or helper functions), which then fail confusingly when
    # called with model kwargs. Require an actual nn.Module subclass instead.
    if not (isinstance(model_cls, type) and issubclass(model_cls, torch.nn.Module)):
        raise ValueError(
            f"Unknown architecture {architecture!r}; expected one of "
            f"{SUPPORTED_ARCHITECTURES}"
        )

    model = model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=num_classes,
    )

    if device is None:
        # Preserve the original default: prefer the GPU when one is present.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return model.to(torch.device(device))
# Reference list of the architecture names this module documents as
# supported (the same names, in the same order, as the module docstring).
# Each entry is the name of a model class in segmentation_models_pytorch.
SUPPORTED_ARCHITECTURES = [
    'Unet',
    'UnetPlusPlus',
    'PSPNet',
    'PAN',
    'MAnet',
    'Linknet',
    'FPN',
    'DeepLabV3Plus',
    'DeepLabV3',
]