Burdenthrive's picture
Update model.py
f73e679 verified
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
class UNet(nn.Module):
    """
    UNet model for multi-class segmentation.

    Thin wrapper around ``smp.Unet`` intended for multi-spectral input images
    (e.g., 13 Sentinel-2 bands) and multiple output classes.
    """

    def __init__(self,
                 encoder_name: str = 'tu-regnetz_d8',
                 encoder_weights=None,
                 in_channels: int = 13,   # Number of input channels (13 for Sentinel-2 multi-spectral images)
                 num_classes: int = 4,    # Number of output classes (e.g., clear, thick cloud, thin cloud, cloud shadow)
                 freeze_encoder: bool = False):  # Whether to freeze the encoder's weights
        """
        Build the underlying ``smp.Unet`` and optionally freeze its encoder.

        Args:
            encoder_name (str): Encoder backbone name forwarded to ``smp.Unet``.
            encoder_weights (str or None): Pretrained weights for the encoder,
                typically 'imagenet' or None (random initialization).
            in_channels (int): Number of input channels (e.g., 13 for Sentinel-2 images).
            num_classes (int): Number of output classes (e.g., 4 for clear,
                cloud types, and shadow).
            freeze_encoder (bool): If True, sets ``requires_grad=False`` on every
                encoder parameter so the optimizer never updates them.
        """
        super().__init__()
        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )
        if freeze_encoder:
            # NOTE(review): requires_grad only blocks gradient updates. If the
            # encoder contains BatchNorm layers, their running statistics still
            # change while the model is in train() mode.
            for param in self.unet.encoder.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Output logits of shape (B, num_classes, H, W).
        """
        return self.unet(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict a class label for each pixel of the input.

        The model is switched to eval mode for the prediction. Its previous
        train/eval mode is restored afterwards, even if the forward pass raises,
        so calling this during training does not disable dropout or BatchNorm
        updates for later training steps.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Predicted labels (int64) of shape (B, H, W).
        """
        was_training = self.training
        self.eval()
        try:
            # Call through __call__ (not .forward) so registered hooks run.
            logits = self(x)  # (B, num_classes, H, W)
        finally:
            self.train(was_training)
        # Softmax is monotonic, so argmax over the raw logits gives the same
        # labels without the extra pass or the rounding ties softmax can create.
        return logits.argmax(dim=1)  # (B, H, W)