File size: 2,430 Bytes
512b655
 
 
 
aa8d960
512b655
aa8d960
512b655
 
 
 
 
f73e679
512b655
 
 
 
 
 
 
 
 
 
aa8d960
512b655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

class UNet(nn.Module):
    """
    UNet model for multi-class segmentation.

    Thin wrapper around ``smp.Unet`` designed for multi-spectral input images
    (e.g., 13 Sentinel-2 bands) and multiple output classes.
    """

    def __init__(self,
                 encoder_name='tu-regnetz_d8',
                 encoder_weights=None,
                 in_channels=13,              # Number of input channels (13 for Sentinel-2 multi-spectral images)
                 num_classes=4,               # Number of output classes (e.g., clear, thick cloud, thin cloud, cloud shadow)
                 freeze_encoder=False):       # Whether to freeze the encoder's weights
        """
        Args:
            encoder_name (str): Encoder backbone name, forwarded unchanged to
                ``smp.Unet(encoder_name=...)``.
            encoder_weights (str or None): Pretrained weights for the encoder,
                typically 'imagenet' or None (random initialization).
            in_channels (int): Number of input channels (e.g., 13 for Sentinel-2 images).
            num_classes (int): Number of output classes (e.g., 4 for clear, cloud types, and shadow).
            freeze_encoder (bool): If True, sets ``requires_grad=False`` on every
                encoder parameter so the optimizer does not update them.
        """
        super().__init__()

        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

        if freeze_encoder:
            # NOTE: this only stops gradient updates. If the encoder contains
            # BatchNorm layers, their running statistics still update whenever
            # the model is in train mode.
            for param in self.unet.encoder.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Output logits of shape (B, num_classes, H, W).
        """
        return self.unet(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predicts multi-class segmentation labels for each pixel in the input image.

        The model is switched to eval mode for the prediction and then restored
        to whatever mode (train/eval) it was in before the call, so calling this
        during a training loop does not leave the model in eval mode.

        Args:
            x (torch.Tensor): Input tensor of shape (B, in_channels, H, W).

        Returns:
            torch.Tensor: Predicted labels of shape (B, H, W).
        """
        was_training = self.training
        self.eval()
        try:
            logits = self(x)                       # (B, num_classes, H, W)
        finally:
            # Restore the caller's mode even if the forward pass raises.
            self.train(was_training)
        # Softmax is monotonic per pixel, so argmax over raw logits gives the
        # same class without the extra computation.
        return logits.argmax(dim=1)                # (B, H, W)