Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import skimage.transform as trans | |
| from skimage.color import rgb2gray | |
| from unet.unet import unet | |
| from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM | |
def _prepare_image(image, h, w):
    """Resize ``image`` to (h, w) and return ``(rgb, gray)`` float arrays.

    ``rgb`` has shape (h, w, 3) and is the canvas the mask is painted on;
    ``gray`` has shape (h, w) and is what the model is fed.  Grayscale input
    (2-D or a single trailing channel) is accepted, and an alpha channel is
    dropped because ``rgb2gray`` only accepts three channels.
    """
    # resize's output_shape is (rows, cols); a trailing channel axis is kept.
    resized = trans.resize(image, (h, w))
    if resized.ndim == 3 and resized.shape[-1] == 1:
        resized = resized[..., 0]
    if resized.ndim == 2:
        # Already grayscale: feed it as-is, replicate channels for the overlay.
        return np.stack([resized] * 3, axis=-1), resized
    rgb = resized[..., :3]
    return rgb, rgb2gray(rgb)


def _load_model(variant, model_path, input_shape, output_channels):
    """Build the U-Net ``variant`` ('v0'..'v3') with its pretrained weights."""
    model_name = "unet" + variant + "_sgd500_neptune"
    model_file = os.path.join(model_path, model_name + ".hdf5")
    if variant == 'v0':
        # The plain U-Net builder takes only the weights path.
        return unet(model_file)
    builders = {
        'v1': UNet_3Plus,
        'v2': UNet_3Plus_DeepSup,
        'v3': UNet_3Plus_DeepSup_CGM,
    }
    return builders[variant](input_shape, output_channels, model_file)


def predict_model(input, unet_type):
    """Segment ``input`` with a pretrained U-Net and overlay the mask in blue.

    Args:
        input: Image array accepted by ``skimage.transform.resize``: RGB,
            RGBA (alpha is ignored) or single-channel grayscale.
        unet_type: Model variant: 'v0' U-Net, 'v1' UNet 3+, 'v2' UNet 3+
            with deep supervision, 'v3' UNet 3+ with CGM.  Any other value
            falls back to 'v3', as before.

    Returns:
        The input resized to 256x256 RGB (float array from ``resize``), with
        every pixel whose predicted value is >= 0.5 painted pure blue.
    """
    model_path = "weights"
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1

    result_img, gray = _prepare_image(input, h, w)

    # Normalise once so model choice and output handling always agree;
    # previously an unknown type loaded the v3 model but parsed its output
    # as if it were single-output.
    variant = unet_type if unet_type in ('v0', 'v1', 'v2') else 'v3'
    model = _load_model(variant, model_path, input_shape, output_channels)
    results = model.predict(gray.reshape(1, h, w, 1))

    # Deep-supervision variants return a list of outputs; like the original,
    # use the first one (presumably the final mask — confirm in unet_3plus).
    prediction = results[0] if variant in ('v2', 'v3') else results
    # First batch item, first channel, thresholded at 0.5.
    mask = np.asarray(prediction)[0, :, :, 0] >= 0.5

    # resize() follows skimage's float convention: 8-bit input comes back in
    # [0, 1], so a hard-coded 255 would be out of range.  Match the image scale.
    peak = 1.0 if result_img.max() <= 1.0 else 255.0
    result_img[mask] = [0, 0, peak]
    return result_img