import segmentation_models_pytorch as smp


def get_model(
    *,
    encoder_name="resnext101_32x4d",
    encoder_weights=None,
    in_channels=3,
    classes=3,
    activation=None,
    **kwargs,
):
    """Build a UNet++ segmentation model.

    Calling ``get_model()`` with no arguments reproduces the original fixed
    configuration:
    - a ``resnext101_32x4d`` encoder;
    - 3 input channels;
    - 3 output classes;
    - no output activation.

    Note: ``encoder_weights=None`` asks segmentation_models_pytorch for a
    randomly initialised encoder (no pretrained download). This function
    does NOT load any checkpoint. If you have your own trained weights, load
    them into the returned model afterwards (e.g. with
    ``model.load_state_dict(...)``).

    Args:
        encoder_name: Backbone name understood by segmentation_models_pytorch.
        encoder_weights: Pretrained-weights identifier for the encoder, or
            ``None`` for random initialisation.
        in_channels: Number of channels in the input images.
        classes: Number of output channels (segmentation classes).
        activation: Optional activation applied to the output. ``None``
            leaves the raw model output untouched.
        **kwargs: Any further keyword arguments, forwarded unchanged to
            ``smp.UnetPlusPlus``.

    Returns:
        The constructed ``smp.UnetPlusPlus`` model.
    """
    return smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
        **kwargs,
    )
import segmentation_models_pytorch as smp


# NOTE(review): this is a second definition of `get_model` in the same file.
# Python keeps only the later definition, so it shadows the earlier one.
# The two copies appear to be an accidental duplicate paste — remove one.
def get_model(
    *,
    encoder_name="resnext101_32x4d",
    encoder_weights=None,
    in_channels=3,
    classes=3,
    activation=None,
    **kwargs,
):
    """Build a UNet++ segmentation model.

    Calling ``get_model()`` with no arguments reproduces the original fixed
    configuration:
    - a ``resnext101_32x4d`` encoder;
    - 3 input channels;
    - 3 output classes;
    - no output activation.

    Note: ``encoder_weights=None`` asks segmentation_models_pytorch for a
    randomly initialised encoder (no pretrained download). This function
    does NOT load any checkpoint. If you have your own trained weights, load
    them into the returned model afterwards (e.g. with
    ``model.load_state_dict(...)``).

    Args:
        encoder_name: Backbone name understood by segmentation_models_pytorch.
        encoder_weights: Pretrained-weights identifier for the encoder, or
            ``None`` for random initialisation.
        in_channels: Number of channels in the input images.
        classes: Number of output channels (segmentation classes).
        activation: Optional activation applied to the output. ``None``
            leaves the raw model output untouched.
        **kwargs: Any further keyword arguments, forwarded unchanged to
            ``smp.UnetPlusPlus``.

    Returns:
        The constructed ``smp.UnetPlusPlus`` model.
    """
    return smp.UnetPlusPlus(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
        **kwargs,
    )