import segmentation_models_pytorch as smp


def get_model(
    *,
    encoder_name: str = "resnext101_32x4d",
    encoder_weights: "str | None" = None,
    in_channels: int = 3,
    classes: int = 3,
    activation: "str | None" = None,
) -> smp.UnetPlusPlus:
    """Build an ``smp.UnetPlusPlus`` segmentation model.

    Called with no arguments, this reproduces the original fixed
    configuration:
    - a ``resnext101_32x4d`` encoder;
    - ``encoder_weights=None``;
    - 3 input channels and 3 output classes;
    - ``activation=None``.

    All options are keyword-only so existing ``get_model()`` callers are
    unaffected.

    Keyword Args:
        encoder_name: Name of the smp encoder backbone.
        encoder_weights: Pretrained encoder weights to use, or ``None`` for
            no pretrained weights. Defaults to ``None`` because the intended
            workflow is to use your own trained weights. NOTE: this function
            does not load them. Presumably the caller restores them into the
            returned model (e.g. via ``load_state_dict``); confirm at call
            sites.
        in_channels: Number of channels in the input images.
        classes: Number of output channels of the segmentation head.
        activation: Activation passed to smp for the segmentation head.
            ``None`` means no activation is applied to the head output.

    Returns:
        The newly constructed ``smp.UnetPlusPlus`` model.
    """
    model = smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
    )
    return model