"""
Baseline Model Factory

This module provides a factory function to create baseline segmentation models
using the segmentation-models-pytorch library.

Supported architectures:
- Unet
- UnetPlusPlus
- PSPNet
- PAN
- MAnet
- Linknet
- FPN
- DeepLabV3Plus
- DeepLabV3
"""
import torch
import segmentation_models_pytorch as smp
# Architectures this factory accepts. Declared before create_model so the
# function can validate against it; extend this list to allow more smp models.
SUPPORTED_ARCHITECTURES = [
    'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
    'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3'
]


class UnsupportedArchitectureError(ValueError, AttributeError):
    """Raised when ``create_model`` is asked for an unknown architecture.

    Inherits from ``AttributeError`` as well as ``ValueError`` so that callers
    which caught the ``AttributeError`` previously raised by the bare
    ``getattr(smp, ...)`` lookup keep working.
    """


def create_model(architecture='Unet', encoder_name='resnet50', in_channels=3,
                 num_classes=1, encoder_weights='imagenet'):
    """
    Create a baseline segmentation model using segmentation-models-pytorch.

    Args:
        architecture (str): Model architecture name. Must be one of the
            entries in ``SUPPORTED_ARCHITECTURES``:
            'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
            'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3'
        encoder_name (str): Encoder backbone name (default: 'resnet50')
        in_channels (int): Number of input channels (default: 3)
        num_classes (int): Number of output classes (default: 1); forwarded
            to the model constructor as ``classes``.
        encoder_weights (str): Pretrained weights source (default: 'imagenet');
            forwarded unchanged to the model constructor.

    Returns:
        nn.Module: Segmentation model moved to CUDA when
        ``torch.cuda.is_available()`` is true, otherwise to the CPU.

    Raises:
        UnsupportedArchitectureError: If ``architecture`` is not listed in
            ``SUPPORTED_ARCHITECTURES``. It is both a ``ValueError`` and an
            ``AttributeError``.

    Example:
        >>> model = create_model('DeepLabV3Plus', encoder_name='resnet50')
        >>> x = torch.randn(1, 3, 448, 448)
        >>> output = model(x)
        >>> output.shape
        torch.Size([1, 1, 448, 448])
    """
    # Validate before touching smp: an unchecked getattr would happily return
    # unrelated attributes (sub-modules, helper functions) and fail later with
    # a confusing error, or raise a bare AttributeError on a simple typo.
    if architecture not in SUPPORTED_ARCHITECTURES:
        raise UnsupportedArchitectureError(
            f"Unsupported architecture {architecture!r}; expected one of: "
            f"{', '.join(SUPPORTED_ARCHITECTURES)}"
        )

    model_cls = getattr(smp, architecture)
    model = model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=num_classes,
    )

    # Prefer the GPU when one is visible to torch; fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return model.to(device)