|
|
import torch |
|
|
import torch.nn as nn |
|
|
import segmentation_models_pytorch as smp |
|
|
|
|
|
class UNet(nn.Module):
    """
    UNet model for multi-class segmentation.

    Thin wrapper around ``smp.Unet`` designed for multi-spectral input images
    (e.g., 13 Sentinel-2 bands) and multiple output classes.
    """

    def __init__(self,
                 encoder_name='tu-regnetz_d8',
                 encoder_weights=None,
                 in_channels=13,
                 num_classes=4,
                 freeze_encoder=False):
        """
        Build the underlying ``smp.Unet`` and optionally freeze its encoder.

        Args:
            encoder_name (str): Name of the encoder backbone, forwarded unchanged
                to ``smp.Unet``.
            encoder_weights (str or None): Weights for the encoder, typically
                'imagenet' or None.
            in_channels (int): Number of input channels (e.g., 13 for Sentinel-2
                images).
            num_classes (int): Number of output classes (e.g., 4 for clear, cloud
                types, and shadow).
            freeze_encoder (bool): If True, freezes the encoder weights during
                training.
        """
        super().__init__()

        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            # Only gradient updates are disabled here.
            # NOTE(review): normalization layers inside the encoder, if any,
            # would still update their running statistics while the module is
            # in train mode -- confirm whether that is acceptable.
            for param in self.unet.encoder.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Output logits of shape (B, num_classes, H, W).
        """
        return self.unet(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predicts multi-class segmentation labels for each pixel in the input image.

        The forward pass runs in eval mode with gradients disabled. The
        train/eval state that every submodule had before the call is restored
        afterwards, so calling this in the middle of a training loop does not
        silently leave the model in eval mode.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Predicted labels of shape (B, H, W).
        """
        # Snapshot per-module flags (not just the top-level one) so that any
        # submodule deliberately kept in a different mode is restored exactly.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(x)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        # Softmax is monotonic along the class dimension, so the arg-max of the
        # logits is the arg-max of the probabilities; skipping softmax avoids an
        # extra pass and rounding-induced ties between nearly equal logits.
        return logits.argmax(dim=1)
|
|
|