File size: 6,634 Bytes
59ecb50
 
 
 
 
 
 
 
 
 
f6069ec
59ecb50
 
 
 
 
 
 
 
 
9c5adc2
59ecb50
 
 
 
 
f9ca5f7
85df3b7
6404238
 
 
 
 
 
 
 
 
59ecb50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6999aa7
 
59ecb50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c5adc2
 
 
 
 
 
 
 
6404238
9c5adc2
 
 
6404238
 
9c5adc2
59ecb50
 
 
 
80fc851
563734b
9c5adc2
 
59ecb50
946565d
59ecb50
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python
# coding=utf-8
# Author: Yao
# Mail: zhangyao215@mails.ucas.ac.cn

import gradio as gr

import os
join = os.path.join
import time
import random
import numpy as np
# from skimage.filters import threshold_otsu
# from skimage.measure import label
import torch
import monai
from monai.inferers import sliding_window_inference
from unetr2d import UNETR2D
import time
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif


def visualize_instance_seg_mask(mask):
    """Render an instance label mask as an RGB image with one random color per instance.

    Args:
        mask: integer label array (normally 2-D, H x W). Label 0 is background;
            every positive value is treated as a separate instance.

    Returns:
        uint8 array of shape ``mask.shape + (3,)``. Background (and any
        non-positive label) is black; each positive label gets one random
        RGB color with every channel drawn from 0..255.
    """
    # `inverse` maps every pixel to the index of its label in `labels`, so the
    # image can be produced with a single palette lookup instead of one full
    # boolean pass over the image per label.
    labels, inverse = np.unique(mask, return_inverse=True)
    palette = np.zeros((len(labels), 3), dtype=np.uint8)
    for idx, label in enumerate(labels):
        if label > 0:
            # All three channels span 0..255 (red was previously limited to 0..1).
            palette[idx] = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    # reshape: NumPy 1.x returns a flat `inverse`, 2.x may return mask-shaped.
    return palette[inverse.reshape(np.shape(mask))]


# Published weights for the 'swinunetr' model, fetched when no local file exists.
_DEFAULT_SWINUNETR_URL = 'https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1'


def _build_model(model_name):
    """Instantiate the (untrained) 2-D, 3-in/3-out network named *model_name*.

    Raises:
        ValueError: if *model_name* is not 'unet', 'unetr' or 'swinunetr'.
    """
    if model_name == 'unet':
        return monai.networks.nets.UNet(
            spatial_dims=2,
            in_channels=3,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    if model_name == 'unetr':
        return UNETR2D(
            in_channels=3,
            out_channels=3,
            img_size=(256, 256),
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )
    if model_name == 'swinunetr':
        return monai.networks.nets.SwinUNETR(
            img_size=(256, 256),
            in_channels=3,
            out_channels=3,
            feature_size=24,  # should be divisible by 12
            spatial_dims=2,
        )
    raise ValueError(f"Unknown model_name {model_name!r}; expected 'unet', 'unetr' or 'swinunetr'")


def _resolve_checkpoint_path(model_name, custom_model_path):
    """Return the path of the checkpoint file to load for *model_name*.

    Preference order: *custom_model_path* if it is an existing file; otherwise,
    for 'swinunetr' only, ``best_Dice_model.pth`` next to this script
    (downloaded from Zenodo first if it is missing).

    Raises:
        FileNotFoundError: for 'unet'/'unetr' when *custom_model_path* is not a file,
            since no default weights exist for those architectures.
    """
    if custom_model_path and os.path.isfile(custom_model_path):
        return custom_model_path
    if model_name != 'swinunetr':
        raise FileNotFoundError(
            f"No checkpoint found at {custom_model_path!r}; a checkpoint file is required for model {model_name!r}"
        )
    default_path = join(os.path.dirname(os.path.abspath(__file__)), 'best_Dice_model.pth')
    if not os.path.isfile(default_path):
        # Download to exactly the path that is loaded afterwards.
        torch.hub.download_url_to_file(_DEFAULT_SWINUNETR_URL, default_path)
    return default_path


def load_model(model_name, custom_model_path):
    """Build *model_name*, load its weights, and return it in eval mode.

    Args:
        model_name: one of 'unet', 'unetr', 'swinunetr'.
        custom_model_path: path to a checkpoint whose dict holds the weights under
            'model_state_dict'. May be None or a missing file for 'swinunetr',
            in which case the default weights are used/downloaded.

    Returns:
        The model on CUDA if available (else CPU), switched to eval mode.

    Raises:
        ValueError: unknown *model_name*.
        FileNotFoundError: no checkpoint available for 'unet'/'unetr'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _build_model(model_name)
    checkpoint = torch.load(_resolve_checkpoint_path(model_name, custom_model_path), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model


def normalize_channel(img, lower=1, upper=99):
    """Contrast-stretch one image channel to uint8 using percentiles of its non-zero pixels.

    Args:
        img: single-channel image array.
        lower, upper: percentiles of the non-zero pixel values mapped to 0 and 255.

    Returns:
        uint8 array of the same shape as *img*. An all-zero channel yields zeros.
        A nearly flat channel (percentile spread <= 0.001) is returned with its
        raw values clipped to 0..255.
    """
    non_zero_vals = img[np.nonzero(img)]
    if non_zero_vals.size == 0:
        # np.percentile fails on an empty array; an all-zero channel needs no stretching.
        return np.zeros(img.shape, dtype=np.uint8)
    low, high = np.percentile(non_zero_vals, [lower, upper])
    if high - low > 0.001:
        img_norm = exposure.rescale_intensity(img, in_range=(low, high), out_range='uint8')
    else:
        # Clip so values above 255 saturate instead of wrapping on the uint8 cast
        # (e.g. a flat uint16 channel of 1000 would otherwise become 232).
        img_norm = np.clip(img, 0, 255)
    return img_norm.astype(np.uint8)


def preprocess(img_data):
    """Convert an input image to a 3-channel uint8 array with per-channel contrast stretching.

    Accepted layouts (channels last):
        (H, W)        grayscale, replicated to 3 channels;
        (H, W, 1)     replicated to 3 channels;
        (H, W, 2)     padded with an all-zero third channel;
        (H, W, C>=3)  first three channels kept (e.g. alpha dropped).

    Each non-empty channel is normalized with ``normalize_channel`` (1st/99th
    percentiles); all-zero channels stay zero.

    Returns:
        uint8 array of shape (H, W, 3).

    Raises:
        ValueError: if *img_data* is not 2-D or 3-D.
    """
    if img_data.ndim == 2:
        img_data = img_data[:, :, np.newaxis]
    if img_data.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D (H, W, C) image, got shape {img_data.shape}")

    n_channels = img_data.shape[-1]
    if n_channels == 1:
        img_data = np.repeat(img_data, 3, axis=-1)
    elif n_channels == 2:
        # The model expects 3 input channels; pad with an empty one.
        img_data = np.concatenate([img_data, np.zeros_like(img_data[:, :, :1])], axis=-1)
    else:
        img_data = img_data[:, :, :3]

    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
    for i in range(3):
        channel = img_data[:, :, i]
        if np.any(channel):
            pre_img_data[:, :, i] = normalize_channel(channel, lower=1, upper=99)
    return pre_img_data


# Models already loaded by get_seg, keyed by (model_name, custom_model_path), so the
# checkpoint is not re-read from disk on every request. Note: a checkpoint file
# replaced on disk is not picked up until the process restarts.
_MODEL_CACHE = {}


def get_seg(pre_img_data, model_name, custom_model_path, threshold):
    """Segment cells in a preprocessed image and return an instance label mask.

    Args:
        pre_img_data: (H, W, 3) image array, e.g. the output of ``preprocess``.
        model_name: architecture name forwarded to ``load_model``.
        custom_model_path: checkpoint path forwarded to ``load_model``.
        threshold: probability cut-off applied to output channel 1 of the model.

    Returns:
        Integer label array (H, W) from ``skimage.measure.label``: 0 is background,
        each connected foreground component gets its own label.
    """
    key = (model_name, custom_model_path)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = load_model(model_name, custom_model_path)
        _MODEL_CACHE[key] = model
    # The input must live on the same device as the model (CUDA when available).
    device = next(model.parameters()).device

    roi_size = (256, 256)
    sw_batch_size = 4
    peak = float(np.max(pre_img_data))
    # Scale to [0, 1]; guard an all-zero image, which would otherwise become NaN.
    img01 = pre_img_data / peak if peak > 0 else np.zeros(pre_img_data.shape, dtype=np.float32)

    with torch.no_grad():
        # (H, W, C) -> (1, C, H, W)
        tensor = torch.from_numpy(img01).float().permute(2, 0, 1).unsqueeze(0).to(device)
        logits = sliding_window_inference(tensor, roi_size, sw_batch_size, model)
        probs = torch.nn.functional.softmax(logits, dim=1)  # (B, C, H, W)
        fg_prob = probs[0, 1].cpu().numpy()

    # Binarize the probability map, fill small holes, drop objects under 16 px, then label.
    binary = fg_prob > threshold
    cleaned = morphology.remove_small_objects(morphology.remove_small_holes(binary), 16)
    return measure.label(cleaned)


def predict(img, threshold=0.5):
    """Gradio callback: segment an uploaded image.

    Args:
        img: value of the ``gr.File`` input — either an object exposing the
            uploaded file path as ``.name`` or the path string itself
            (depends on the Gradio version).
        threshold: foreground probability threshold from the slider.

    Returns:
        (raw image array, RGB instance visualization, path to a zlib-compressed
        TIFF holding the integer label mask).

    Raises:
        gr.Error: if no file was uploaded.
    """
    import tempfile

    if img is None:
        raise gr.Error("Please upload an input image.")
    img_path = getattr(img, 'name', img)

    if img_path.lower().endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_path)
    else:
        img_data = io.imread(img_path)

    seg_labels = get_seg(preprocess(img_data), 'swinunetr', './best_Dice_model.pth', float(threshold))
    seg_rgb = visualize_instance_seg_mask(seg_labels)

    # Write into a fresh per-request directory so concurrent requests don't overwrite
    # each other's result; the downloaded file keeps the name 'segmentation.tiff'.
    out_path = join(tempfile.mkdtemp(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')
    return img_data, seg_rgb, out_path


# Gradio UI: upload an image and choose a threshold; shows the input image and the
# colored instance segmentation, and offers the label mask as a TIFF download.
demo = gr.Interface(
    predict,
    inputs=[
        gr.File(label="input image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="threshold"),
    ],
    outputs=[
        gr.Image(label="image"),
        gr.Image(label="segmentation"),
        gr.File(label="download segmentation"),
    ],
    title="NeurIPS CellSeg Demo",
)

# Launch only when run as a script, so importing this module doesn't start a server.
if __name__ == "__main__":
    demo.launch()