# NeurIPS_CellSeg / app.py
# Yao Zhang
# debug
# 6404238
#!/usr/bin/env python
# coding=utf-8
# Author: Yao
# Mail: zhangyao215@mails.ucas.ac.cn
import gradio as gr
import os
join = os.path.join
import time
import random
import numpy as np
# from skimage.filters import threshold_otsu
# from skimage.measure import label
import torch
import monai
from monai.inferers import sliding_window_inference
from unetr2d import UNETR2D
import time
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif
def visualize_instance_seg_mask(mask):
    """Render an instance-segmentation label mask as an RGB image.

    Each instance label (> 0) is assigned a random RGB color; the
    background (label 0) is kept black.

    Args:
        mask: 2-D integer array of instance labels.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    labels = np.unique(mask)
    # Bug fix: the red channel was sampled with randint(0, 1) while the
    # other two channels used randint(0, 255), so label colors almost
    # never had a visible red component. Sample all three channels alike.
    label2color = {
        label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
        if label > 0
    }
    label2color[0] = (0, 0, 0)
    for label in labels:
        image[mask == label, :] = label2color[label]
    return image.astype(np.uint8)
def load_model(model_name, custom_model_path):
    """Build a 2-D segmentation network and load trained weights onto
    the best available device.

    Checkpoint resolution order:
      1. `custom_model_path` if it exists,
      2. `best_Dice_model.pth` next to this script,
      3. otherwise download the released checkpoint from Zenodo to that
         same location and load it.

    Args:
        model_name: one of 'unet', 'unetr', 'swinunetr'.
        custom_model_path: preferred path to a checkpoint file.

    Returns:
        The model in eval mode, moved to GPU if available else CPU.

    Raises:
        ValueError: if `model_name` is not one of the supported names.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_name == 'unet':
        model = monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    elif model_name == 'unetr':
        model = UNETR2D(
            in_channels=3,
            out_channels=3,
            img_size=(256, 256),
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )
    elif model_name == 'swinunetr':
        model = monai.networks.nets.SwinUNETR(
            img_size=(256, 256),
            in_channels=3,
            out_channels=3,
            feature_size=24,  # should be divisible by 12
            spatial_dims=2
        )
    else:
        # Bug fix: previously an unknown name left `model` unbound and
        # crashed later with a confusing NameError.
        raise ValueError(f"unknown model_name: {model_name!r}")
    default_ckpt = join(os.path.dirname(__file__), 'best_Dice_model.pth')
    if os.path.isfile(custom_model_path):
        ckpt_path = custom_model_path
    elif os.path.isfile(default_ckpt):
        ckpt_path = default_ckpt
    else:
        # Bug fix: the checkpoint used to be downloaded to
        # 'work_dir/swinunetr/best_Dice_model.pth' but then loaded from the
        # script directory, so the fallback could never work. Download to
        # the exact path we load from.
        torch.hub.download_url_to_file(
            'https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1',
            default_ckpt,
        )
        ckpt_path = default_ckpt
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model
def normalize_channel(img, lower=1, upper=99):
    """Percentile-based intensity normalization for a single channel.

    Stretches intensities between the `lower` and `upper` percentiles of
    the channel's non-zero pixels to the full uint8 range. If the
    percentile window is (near-)degenerate, the channel is returned
    unchanged, only cast to uint8.
    """
    nonzero_vals = img[np.nonzero(img)]
    p_low, p_high = np.percentile(nonzero_vals, [lower, upper])
    if p_high - p_low > 0.001:
        stretched = exposure.rescale_intensity(
            img, in_range=(p_low, p_high), out_range='uint8'
        )
        return stretched.astype(np.uint8)
    return img.astype(np.uint8)
def preprocess(img_data):
    """Convert an input image to a normalized 3-channel uint8 array.

    Grayscale inputs are replicated across 3 channels; images with more
    than 3 channels are truncated to the first 3. Each channel that has
    any non-zero pixel is percentile-normalized via `normalize_channel`;
    all-zero channels are left as zeros.
    """
    if img_data.ndim == 2:
        img_data = np.stack([img_data] * 3, axis=-1)
    elif img_data.ndim == 3 and img_data.shape[-1] > 3:
        img_data = img_data[:, :, :3]
    normalized = np.zeros(img_data.shape, dtype=np.uint8)
    for ch in range(3):
        channel = img_data[:, :, ch]
        if np.count_nonzero(channel) > 0:
            normalized[:, :, ch] = normalize_channel(channel, lower=1, upper=99)
    return normalized
def get_seg(pre_img_data, model_name, custom_model_path, threshold):
    """Run sliding-window inference and return an instance label mask.

    Args:
        pre_img_data: preprocessed (H, W, 3) image array.
        model_name: forwarded to `load_model`.
        custom_model_path: forwarded to `load_model`.
        threshold: probability cutoff for the foreground class.

    Returns:
        2-D integer array of instance labels (connected components of the
        cleaned foreground mask).
    """
    model = load_model(model_name, custom_model_path)
    window_size = (256, 256)
    window_batch = 4
    with torch.no_grad():
        start = time.time()
        # Scale to [0, 1] and arrange as a (1, C, H, W) float tensor.
        scaled = pre_img_data / np.max(pre_img_data)
        input_tensor = (
            torch.from_numpy(np.expand_dims(scaled, 0))
            .permute(0, 3, 1, 2)
            .type(torch.FloatTensor)
        )
        logits = sliding_window_inference(input_tensor, window_size, window_batch, model)
        probs = torch.nn.functional.softmax(logits, dim=1)  # (B, C, H, W)
        fg_prob = probs[0, 1].cpu().numpy()
        # Binarize, fill small holes, drop small objects, then label the
        # remaining connected components as instances.
        binary = fg_prob > threshold
        cleaned = morphology.remove_small_objects(
            morphology.remove_small_holes(binary), 16
        )
        instance_mask = measure.label(cleaned)
        elapsed = time.time() - start  # inference timing (kept for debugging)
    return instance_mask
def predict(img, threshold=0.5):
    """Gradio handler: segment an uploaded image file.

    Args:
        img: uploaded file object exposing its path via `.name`.
        threshold: foreground probability cutoff.

    Returns:
        Tuple of (original image array, RGB visualization of the
        segmentation, path to the saved label TIFF).
    """
    print('##########', img.name)
    img_name = img.name
    # TIFF files go through tifffile; everything else through skimage.
    if img_name.endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)
    seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
    seg_rgb = visualize_instance_seg_mask(seg_labels)
    out_path = join(os.getcwd(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')
    print(np.max(img_data), np.min(img_data))
    print(np.max(seg_rgb), np.min(seg_rgb))
    return img_data, seg_rgb, out_path
# Build the Gradio UI: a file upload plus a threshold slider on the input
# side; the raw image, the colorized segmentation, and a downloadable
# label TIFF on the output side.
demo = gr.Interface(
    predict,
    # inputs=[gr.Image()],
    # inputs="file",
    inputs=[gr.File(label="input image"), gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="threshold")],
    outputs=[gr.Image(label="image"), gr.Image(label="segmentation"), gr.File(label="download segmentation")],
    title="NeurIPS CellSeg Demo",
    # examples=[["cell_00225.png"]]
)
demo.launch()