# NOTE(review): the lines that were here ("Spaces: Sleeping", file size,
# git object hashes, and a run of line numbers) were Hugging Face
# file-viewer scrape residue, not part of the program; they broke the
# Python syntax and have been commented out of existence.
#!/usr/bin/env python
# coding=utf-8
# Author: Yao
# Mail: zhangyao215@mails.ucas.ac.cn
import gradio as gr
import os
join = os.path.join
import time
import random
import numpy as np
# from skimage.filters import threshold_otsu
# from skimage.measure import label
import torch
import monai
from monai.inferers import sliding_window_inference
from unetr2d import UNETR2D
import time
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif
def visualize_instance_seg_mask(mask):
    """Render an integer instance-label mask as an RGB image.

    Each positive label id gets a random RGB colour; the background
    (label 0) is black.

    Args:
        mask: 2-D integer array of instance labels (0 = background).

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    image = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    # Bug fix: the red channel was previously drawn from randint(0, 1),
    # so almost every instance colour had no red component at all.
    # All three channels now span the full 0-255 range.
    label2color = {
        label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
        if label > 0
    }
    label2color[0] = (0, 0, 0)
    for label in labels:
        image[mask == label, :] = label2color[label]
    return image
def load_model(model_name, custom_model_path):
    """Build a 2-D cell-segmentation network and load trained weights.

    Args:
        model_name: one of 'unet', 'unetr', 'swinunetr'.
        custom_model_path: path to a checkpoint; if it does not exist, a
            bundled/downloaded ``best_Dice_model.pth`` is used instead.

    Returns:
        The model in eval mode, on CUDA if available, else CPU.

    Raises:
        ValueError: if ``model_name`` is not recognised (previously this
            left ``model`` unbound and crashed with a NameError later).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_name == 'unet':
        model = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    elif model_name == 'unetr':
        model = UNETR2D(
            in_channels=3,
            out_channels=3,
            img_size=(256, 256),
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )
    elif model_name == 'swinunetr':
        model = monai.networks.nets.SwinUNETR(
            img_size=(256, 256),
            in_channels=3,
            out_channels=3,
            feature_size=24,  # should be divisible by 12
            spatial_dims=2
        )
    else:
        raise ValueError(f"unknown model_name: {model_name!r}")
    default_ckpt = join(os.path.dirname(__file__), 'best_Dice_model.pth')
    if os.path.isfile(custom_model_path):
        ckpt_path = custom_model_path
    elif os.path.isfile(default_ckpt):
        ckpt_path = default_ckpt
    else:
        # Bug fix: the checkpoint used to be downloaded to
        # 'work_dir/swinunetr/best_Dice_model.pth' but then loaded from
        # 'best_Dice_model.pth', which still did not exist. Download to
        # the same path that is loaded.
        torch.hub.download_url_to_file(
            'https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1',
            default_ckpt)
        ckpt_path = default_ckpt
    checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model
def normalize_channel(img, lower=1, upper=99):
    """Rescale one image channel to uint8 by non-zero percentile stretch.

    Args:
        img: 2-D array for a single channel.
        lower, upper: percentiles (over non-zero pixels only) mapped to
            the uint8 output range.

    Returns:
        uint8 array of the same shape.
    """
    non_zero_vals = img[np.nonzero(img)]
    # Guard: an all-zero channel has no non-zero pixels and np.percentile
    # would raise on the empty array — return the channel unchanged.
    if non_zero_vals.size == 0:
        return img.astype(np.uint8)
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(
            img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
    else:
        # Degenerate intensity spread: rescaling would be a no-op / divide
        # by ~zero, so pass the values through.
        img_norm = img
    return img_norm.astype(np.uint8)
def preprocess(img_data):
    """Coerce an input image to 3 channels and percentile-normalise each.

    Grayscale inputs are replicated across 3 channels; images with more
    than 3 channels are truncated to the first 3. Channels that contain
    no non-zero pixel are left as zeros.

    Returns:
        (H, W, 3) uint8 array.
    """
    if img_data.ndim == 2:
        # Grayscale: replicate into a 3-channel stack.
        img_data = np.repeat(img_data[..., None], 3, axis=-1)
    elif img_data.ndim == 3 and img_data.shape[-1] > 3:
        # Extra channels (e.g. alpha): keep only the first three.
        img_data = img_data[..., :3]
    normalized = np.zeros(img_data.shape, dtype=np.uint8)
    for ch in range(3):
        channel = img_data[..., ch]
        if channel[np.nonzero(channel)].size > 0:
            normalized[..., ch] = normalize_channel(channel, lower=1, upper=99)
    return normalized
def get_seg(pre_img_data, model_name, custom_model_path, threshold):
    """Segment cell instances in a preprocessed image.

    Parameters
    ----------
    pre_img_data : np.ndarray
        (H, W, 3) image as produced by ``preprocess``.
    model_name : str
        One of 'unet' / 'unetr' / 'swinunetr'; forwarded to ``load_model``.
    custom_model_path : str
        Checkpoint path; forwarded to ``load_model``.
    threshold : float
        Probability cutoff used to binarise the class-1 probability map.

    Returns
    -------
    np.ndarray
        2-D integer instance-label mask from ``measure.label``.
    """
    model = load_model(model_name, custom_model_path)
    #%%
    # Sliding-window settings: 256x256 tiles (the models' img_size),
    # 4 tiles per forward pass.
    roi_size = (256, 256)
    sw_batch_size = 4
    with torch.no_grad():
        t0 = time.time()
        # Scale intensities to [0, 1] by the global maximum.
        # NOTE(review): assumes pre_img_data is not all zeros — an all-zero
        # input would divide by 0 here; confirm upstream guarantees.
        test_npy01 = pre_img_data/np.max(pre_img_data)
        # test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
        # NHWC -> NCHW float tensor with a singleton batch dimension (kept on CPU).
        test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0,3,1,2).type(torch.FloatTensor)
        test_pred_out = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
        test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1) # (B, C, H, W)
        # Class channel 1 of the 3-class output — presumably the cell
        # interior/foreground class, given it is thresholded below; confirm
        # against training label encoding.
        test_pred_npy = test_pred_out[0,1].cpu().numpy()
        # convert probability map to binary mask and apply morphological postprocessing:
        # threshold -> fill small holes -> drop objects smaller than 16 px
        # -> connected-component labelling.
        test_pred_mask = measure.label(morphology.remove_small_objects(morphology.remove_small_holes(test_pred_npy>threshold),16))
        # tif.imwrite(join(output_path, img_name.split('.')[0]+'_label.tiff'), test_pred_mask, compression='zlib')
        t1 = time.time()
        # print(f'Prediction finished: {img_layer.name}; img size = {pre_img_data.shape}; costing: {t1-t0:.2f}s')
    return test_pred_mask
def predict(img, threshold=0.5):
    """Gradio callback: read the uploaded file, segment it, save the mask.

    Args:
        img: uploaded file object (has a ``.name`` path attribute).
        threshold: foreground probability cutoff for ``get_seg``.

    Returns:
        (input image array, RGB segmentation overlay, path to the saved
        instance-label TIFF).
    """
    img_name = img.name
    print('##########', img_name)
    # TIFFs go through tifffile; everything else through skimage.io.
    if img_name.endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)
    seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
    seg_rgb = visualize_instance_seg_mask(seg_labels)
    out_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')
    print(np.max(img_data), np.min(img_data))
    print(np.max(seg_rgb), np.min(seg_rgb))
    return img_data, seg_rgb, out_path
# --- Gradio UI wiring -------------------------------------------------------
# File upload + threshold slider in; original image, coloured segmentation,
# and a downloadable instance-label TIFF out.
demo = gr.Interface(
    predict,
    # alternative input configurations tried during development:
    # inputs=[gr.Image()],
    # inputs="file",
    inputs=[gr.File(label="input image"), gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="threshold")],
    outputs=[gr.Image(label="image"), gr.Image(label="segmentation"), gr.File(label="download segmentation")],
    title="NeurIPS CellSeg Demo",
    # examples=[["cell_00225.png"]]
)
demo.launch()