File size: 283 Bytes
c5bce9d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import segmentation_models_pytorch as smp

def get_model(
    *,
    encoder_name: str = "resnext101_32x4d",
    encoder_weights=None,
    in_channels: int = 3,
    classes: int = 3,
    activation=None,
):
    """Build and return an ``smp.UnetPlusPlus`` segmentation model.

    Calling ``get_model()`` with no arguments builds the same network as
    before: a ResNeXt-101 (32x4d) encoder, 3 input channels, 3 output
    classes and no output activation. All arguments are keyword-only, so
    existing callers are unaffected. Other configurations can be built
    without editing this function.

    Args:
        encoder_name: Backbone name understood by segmentation_models_pytorch.
        encoder_weights: Pretrained-weights identifier for the encoder, or
            ``None`` to start from random initialisation.
        in_channels: Number of channels in the input images.
        classes: Number of output channels (segmentation classes).
        activation: Activation applied to the model output. With ``None``,
            smp applies no activation, so the model returns raw scores.

    Returns:
        A freshly constructed, untrained ``smp.UnetPlusPlus`` instance.
    """
    # encoder_weights defaults to None because the original code notes that
    # its own trained weights are used. Those weights are presumably loaded
    # into the returned model by the caller (e.g. via load_state_dict).
    # NOTE(review): confirm where that happens; it is not visible here.
    return smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
    )