# Pavement-Evaluation / src/models/baselines.py
# Last commit by Blessing988 — "deploy: src/models/baselines.py" (84b07bc, verified)
"""
Baseline Model Factory
This module provides a factory function to create baseline segmentation models
using the segmentation-models-pytorch library.
Supported architectures:
- Unet
- UnetPlusPlus
- PSPNet
- PAN
- MAnet
- Linknet
- FPN
- DeepLabV3Plus
- DeepLabV3
"""
import torch
import segmentation_models_pytorch as smp
def create_model(architecture='Unet', encoder_name='resnet50', in_channels=3,
                 num_classes=1, encoder_weights='imagenet', device=None):
    """
    Create a baseline segmentation model using segmentation-models-pytorch.

    Args:
        architecture (str): Name of a model class exported by
            ``segmentation_models_pytorch``. Documented choices:
            'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
            'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3'
        encoder_name (str): Encoder backbone name (default: 'resnet50')
        in_channels (int): Number of input channels (default: 3)
        num_classes (int): Number of output classes; forwarded to smp as
            ``classes`` (default: 1)
        encoder_weights (str or None): Pretrained weights source, passed
            through to smp unchanged (default: 'imagenet')
        device (str, torch.device or None): Device to move the model to.
            When None (default), 'cuda' is used if available, else 'cpu'.

    Returns:
        nn.Module: Segmentation model on the selected device. The model is
        not switched to eval mode; call ``.eval()`` before inference.

    Raises:
        ValueError: If ``architecture`` is not the name of an ``nn.Module``
            subclass exported by segmentation_models_pytorch.

    Example:
        >>> model = create_model('DeepLabV3Plus', encoder_name='resnet50').eval()
        >>> device = next(model.parameters()).device
        >>> x = torch.randn(1, 3, 448, 448, device=device)
        >>> with torch.no_grad():
        ...     output = model(x)
        >>> output.shape
        torch.Size([1, 1, 448, 448])
    """
    # A bare getattr on the smp package would also return helper functions
    # (e.g. smp.create_model) or submodules, which then fail later with
    # confusing errors. Only accept actual model classes, and fail fast.
    model_cls = None
    if isinstance(architecture, str):
        model_cls = getattr(smp, architecture, None)
    if not (isinstance(model_cls, type) and issubclass(model_cls, torch.nn.Module)):
        raise ValueError(
            f"Unknown segmentation architecture {architecture!r}; "
            f"supported architectures: {', '.join(SUPPORTED_ARCHITECTURES)}"
        )

    model = model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=num_classes,
    )

    # Preserve the original default: prefer CUDA when it is available.
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return model.to(device)
# Names of the segmentation_models_pytorch model classes documented as
# supported baselines in this module (kept in the same order as the docs).
SUPPORTED_ARCHITECTURES = [
    "Unet",
    "UnetPlusPlus",
    "PSPNet",
    "PAN",
    "MAnet",
    "Linknet",
    "FPN",
    "DeepLabV3Plus",
    "DeepLabV3",
]