#!/usr/bin/env python
# coding=utf-8
# Author: Yao
# Mail: zhangyao215@mails.ucas.ac.cn
import os
import time
import random

import gradio as gr
import numpy as np
import torch
import monai
import tifffile as tif
from monai.inferers import sliding_window_inference
from skimage import io, segmentation, morphology, measure, exposure

from unetr2d import UNETR2D

join = os.path.join
def visualize_instance_seg_mask(mask):
    """Map each instance label to a random RGB color for display; background stays black."""
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    labels = np.unique(mask)
    label2color = {
        label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels if label > 0
    }
    label2color[0] = (0, 0, 0)
    for label in labels:
        image[mask == label, :] = label2color[label]
    return image.astype(np.uint8)
def load_model(model_name, custom_model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_name == 'unet':
        model = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    elif model_name == 'unetr':
        model = UNETR2D(
            in_channels=3,
            out_channels=3,
            img_size=(256, 256),
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )
    elif model_name == 'swinunetr':
        model = monai.networks.nets.SwinUNETR(
            img_size=(256, 256),
            in_channels=3,
            out_channels=3,
            feature_size=24,  # should be divisible by 12
            spatial_dims=2,
        )
    # locate the checkpoint: user-supplied path, local copy, or download from Zenodo
    default_ckpt = join(os.path.dirname(__file__), 'best_Dice_model.pth')
    if os.path.isfile(custom_model_path):
        checkpoint = torch.load(custom_model_path, map_location=device)
    elif os.path.isfile(default_ckpt):
        checkpoint = torch.load(default_ckpt, map_location=device)
    else:
        torch.hub.download_url_to_file(
            'https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1',
            default_ckpt,
        )
        checkpoint = torch.load(default_ckpt, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model
def normalize_channel(img, lower=1, upper=99):
    """Percentile-based intensity normalization of one channel to uint8."""
    non_zero_vals = img[np.nonzero(img)]
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(
            img, in_range=(percentiles[0], percentiles[1]), out_range='uint8'
        )
    else:
        img_norm = img
    return img_norm.astype(np.uint8)
def preprocess(img_data):
    """Convert the input to a 3-channel image with per-channel percentile normalization."""
    if len(img_data.shape) == 2:
        # grayscale -> replicate to 3 channels
        img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
    elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
        # drop extra channels (e.g. alpha)
        img_data = img_data[:, :, :3]
    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
    for i in range(3):
        img_channel_i = img_data[:, :, i]
        if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
    return pre_img_data
def get_seg(pre_img_data, model_name, custom_model_path, threshold):
    model = load_model(model_name, custom_model_path)
    device = next(model.parameters()).device
    roi_size = (256, 256)
    sw_batch_size = 4
    with torch.no_grad():
        t0 = time.time()
        test_npy01 = pre_img_data / np.max(pre_img_data)
        test_tensor = (
            torch.from_numpy(np.expand_dims(test_npy01, 0))
            .permute(0, 3, 1, 2)
            .type(torch.FloatTensor)
            .to(device)
        )
        test_pred_out = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
        test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1)  # (B, C, H, W)
        test_pred_npy = test_pred_out[0, 1].cpu().numpy()
        # convert the probability map to a binary mask and apply morphological postprocessing
        test_pred_mask = measure.label(
            morphology.remove_small_objects(
                morphology.remove_small_holes(test_pred_npy > threshold), 16
            )
        )
        t1 = time.time()
        print(f'Prediction finished; img size = {pre_img_data.shape}; costing: {t1 - t0:.2f}s')
    return test_pred_mask
def predict(img, threshold=0.5):
    img_name = img.name
    print('Input file:', img_name)
    if img_name.endswith('.tif') or img_name.endswith('.tiff'):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)
    seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
    seg_rgb = visualize_instance_seg_mask(seg_labels)
    seg_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(seg_path, seg_labels, compression='zlib')
    return img_data, seg_rgb, seg_path
demo = gr.Interface(
    predict,
    inputs=[
        gr.File(label="input image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="threshold"),
    ],
    outputs=[
        gr.Image(label="image"),
        gr.Image(label="segmentation"),
        gr.File(label="download segmentation"),
    ],
    title="NeurIPS CellSeg Demo",
    # examples=[["cell_00225.png"]]
)

demo.launch()
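
# A minimal usage sketch (assumption: a local image such as 'cell_00001.png'
# exists; the output file names are placeholders) showing how the functions
# above could be called directly, without the Gradio UI, e.g. for batch use:
#
#   img = io.imread('cell_00001.png')
#   labels = get_seg(preprocess(img), 'swinunetr', './best_Dice_model.pth', 0.5)
#   io.imsave('cell_00001_instances.png', visualize_instance_seg_mask(labels))
#   tif.imwrite('cell_00001_labels.tiff', labels, compression='zlib')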