Burdenthrive commited on
Commit
aa8d960
·
verified ·
1 Parent(s): 512b655

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -2,9 +2,9 @@ import torch
2
  import torch.nn as nn
3
  import segmentation_models_pytorch as smp
4
 
5
- class UNetWithRegNetYBackbone(nn.Module):
6
  """
7
- UNet model for multi-class segmentation with RegNetY_006 backbone.
8
  Designed for multi-spectral input images (e.g., 13 Sentinel-2 bands) and multiple output classes.
9
  """
10
 
@@ -21,7 +21,7 @@ class UNetWithRegNetYBackbone(nn.Module):
21
  num_classes (int): Number of output classes (e.g., 4 for clear, cloud types, and shadow).
22
  freeze_encoder (bool): If True, freezes the encoder weights during training.
23
  """
24
- super(UNetWithRegNetYBackbone, self).__init__()
25
 
26
 
27
  self.unet = smp.Unet(
 
2
  import torch.nn as nn
3
  import segmentation_models_pytorch as smp
4
 
5
+ class UNet(nn.Module):
6
  """
7
+ UNet model for multi-class segmentation.
8
  Designed for multi-spectral input images (e.g., 13 Sentinel-2 bands) and multiple output classes.
9
  """
10
 
 
21
  num_classes (int): Number of output classes (e.g., 4 for clear, cloud types, and shadow).
22
  freeze_encoder (bool): If True, freezes the encoder weights during training.
23
  """
24
+ super(UNet, self).__init__()
25
 
26
 
27
  self.unet = smp.Unet(