Update model.py
Browse files
model.py
CHANGED
|
@@ -2,9 +2,9 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import segmentation_models_pytorch as smp
|
| 4 |
|
| 5 |
-
class
|
| 6 |
"""
|
| 7 |
-
UNet model for multi-class segmentation
|
| 8 |
Designed for multi-spectral input images (e.g., 13 Sentinel-2 bands) and multiple output classes.
|
| 9 |
"""
|
| 10 |
|
|
@@ -21,7 +21,7 @@ class UNetWithRegNetYBackbone(nn.Module):
|
|
| 21 |
num_classes (int): Number of output classes (e.g., 4 for clear, cloud types, and shadow).
|
| 22 |
freeze_encoder (bool): If True, freezes the encoder weights during training.
|
| 23 |
"""
|
| 24 |
-
super(
|
| 25 |
|
| 26 |
|
| 27 |
self.unet = smp.Unet(
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import segmentation_models_pytorch as smp
|
| 4 |
|
| 5 |
+
class UNet(nn.Module):
|
| 6 |
"""
|
| 7 |
+
UNet model for multi-class segmentation.
|
| 8 |
Designed for multi-spectral input images (e.g., 13 Sentinel-2 bands) and multiple output classes.
|
| 9 |
"""
|
| 10 |
|
|
|
|
| 21 |
num_classes (int): Number of output classes (e.g., 4 for clear, cloud types, and shadow).
|
| 22 |
freeze_encoder (bool): If True, freezes the encoder weights during training.
|
| 23 |
"""
|
| 24 |
+
super(UNet, self).__init__()
|
| 25 |
|
| 26 |
|
| 27 |
self.unet = smp.Unet(
|