Yao Zhang committed
Commit 59ecb50 · 1 Parent(s): 45ecef6
Files changed (7)
  1. README.md +3 -3
  2. __init__.py +0 -0
  3. app.py +144 -0
  4. best_Dice_model.pth +3 -0
  5. packages.txt +0 -0
  6. requirements.txt +6 -0
  7. unetr2d.py +202 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: NeurIPS CellSeg
- emoji: 📊
- colorFrom: indigo
- colorTo: pink
+ emoji: 🔥
+ colorFrom: yellow
+ colorTo: red
  sdk: gradio
  sdk_version: 3.4.1
  app_file: app.py
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,144 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Author: Yao
+ # Mail: zhangyao215@mails.ucas.ac.cn
+
+ import os
+ join = os.path.join
+ import random
+ import time
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import monai
+ from monai.inferers import sliding_window_inference
+ from skimage import io, segmentation, morphology, measure, exposure
+
+ from unetr2d import UNETR2D
+
+
+ def visualize_instance_seg_mask(mask):
+     # Map each instance label to a random RGB color for display.
+     image = np.zeros((mask.shape[0], mask.shape[1], 3))
+     labels = np.unique(mask)
+     label2color = {label: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for label in labels}
+     for i in range(image.shape[0]):
+         for j in range(image.shape[1]):
+             image[i, j, :] = label2color[mask[i, j]]
+     image = image / 255
+     return image
+
+
+ def load_model(model_name, custom_model_path):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     if model_name == 'unet':
+         model = monai.networks.nets.UNet(
+             spatial_dims=2,
+             in_channels=3,
+             out_channels=3,
+             channels=(16, 32, 64, 128, 256),
+             strides=(2, 2, 2, 2),
+             num_res_units=2,
+         )
+     elif model_name == 'unetr':
+         model = UNETR2D(
+             in_channels=3,
+             out_channels=3,
+             img_size=(256, 256),
+             feature_size=16,
+             hidden_size=768,
+             mlp_dim=3072,
+             num_heads=12,
+             pos_embed="perceptron",
+             norm_name="instance",
+             res_block=True,
+             dropout_rate=0.0,
+         )
+     elif model_name == 'swinunetr':
+         model = monai.networks.nets.SwinUNETR(
+             img_size=(256, 256),
+             in_channels=3,
+             out_channels=3,
+             feature_size=24,  # should be divisible by 12
+             spatial_dims=2,
+         )
+
+     # Find a checkpoint: user-supplied path, local copy, or the Zenodo release.
+     default_ckpt = join(os.path.dirname(__file__), 'best_Dice_model.pth')
+     if os.path.isfile(custom_model_path):
+         checkpoint = torch.load(custom_model_path, map_location=device)
+     elif os.path.isfile(default_ckpt):
+         checkpoint = torch.load(default_ckpt, map_location=device)
+     else:
+         torch.hub.download_url_to_file('https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1', default_ckpt)
+         checkpoint = torch.load(default_ckpt, map_location=device)
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model = model.to(device)
+     model.eval()
+     return model
+
+
+ def normalize_channel(img, lower=1, upper=99):
+     # Rescale one channel to uint8 using the given intensity percentiles.
+     non_zero_vals = img[np.nonzero(img)]
+     percentiles = np.percentile(non_zero_vals, [lower, upper])
+     if percentiles[1] - percentiles[0] > 0.001:
+         img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
+     else:
+         img_norm = img
+     return img_norm.astype(np.uint8)
+
+
+ def preprocess(img_data):
+     # Grayscale -> 3 channels; drop any channels beyond RGB.
+     if len(img_data.shape) == 2:
+         img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
+     elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
+         img_data = img_data[:, :, :3]
+     pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
+     for i in range(3):
+         img_channel_i = img_data[:, :, i]
+         if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
+             pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
+     return pre_img_data
+
+
+ def get_seg(pre_img_data, model_name, custom_model_path, threshold):
+     model = load_model(model_name, custom_model_path)
+     device = next(model.parameters()).device
+     roi_size = (256, 256)
+     sw_batch_size = 4
+     with torch.no_grad():
+         t0 = time.time()
+         test_npy01 = pre_img_data / np.max(pre_img_data)
+         test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
+         test_pred_out = sliding_window_inference(test_tensor, roi_size, sw_batch_size, model)
+         test_pred_out = torch.nn.functional.softmax(test_pred_out, dim=1)  # (B, C, H, W)
+         test_pred_npy = test_pred_out[0, 1].cpu().numpy()
+         # Convert the probability map to a binary mask and apply morphological postprocessing.
+         test_pred_mask = measure.label(morphology.remove_small_objects(morphology.remove_small_holes(test_pred_npy > threshold), 16))
+         t1 = time.time()
+         print(f'Prediction finished; img size = {pre_img_data.shape}; costing: {t1 - t0:.2f}s')
+     return test_pred_mask
+
+
+ def predict(img):
+     seg_labels = get_seg(preprocess(img), 'swinunetr', './best_Dice_model.pth', 0.5)
+     # Render the instance labels as a color image for display.
+     return visualize_instance_seg_mask(seg_labels)
+
+
+ demo = gr.Interface(
+     predict,
+     inputs=[gr.Image()],
+     outputs="image",
+     title="NeurIPS CellSeg Demo",
+     examples=[["cell_00225.png"]],
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
+
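With demo.launch() behind a main guard, the same pipeline can also be exercised headlessly. A minimal sketch (a hypothetical run_local.py; it assumes the checkpoint and the bundled cell_00225.png example sit next to app.py):

    # run_local.py: hypothetical smoke test for the Space's inference pipeline.
    from skimage import io
    from app import preprocess, get_seg, visualize_instance_seg_mask

    img = io.imread("cell_00225.png")  # example image shipped with the Space
    labels = get_seg(preprocess(img), "swinunetr", "./best_Dice_model.pth", 0.5)
    # Colorize the instance labels and save the result as an 8-bit image.
    overlay = (visualize_instance_seg_mask(labels) * 255).astype("uint8")
    io.imsave("cell_00225_label.png", overlay)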
best_Dice_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:764db0c53184da5bf743db84d8837b1f9b2c7f3b0236b43c75ff747e47c75e5a
+ size 75949863
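This is a Git LFS pointer, not the weights themselves; after cloning, `git lfs pull` materializes the real file. Alternatively, the checkpoint can be fetched directly from the Zenodo record that app.py already falls back to; a minimal sketch:

    import torch

    # Same Zenodo URL as the fallback branch in app.py's load_model().
    torch.hub.download_url_to_file(
        "https://zenodo.org/record/6792177/files/best_Dice_model.pth?download=1",
        "best_Dice_model.pth",
    )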
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy
+ scikit-image
+ torch
+ monai
+ einops
unetr2d.py ADDED
@@ -0,0 +1,202 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sun Mar 20 14:23:19 2022
+ Author: MONAI
+ """
+
+ from typing import Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
+ from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
+ from monai.networks.nets import ViT
+
+
+ class UNETR2D(nn.Module):
+     """
+     UNETR based on: "Hatamizadeh et al.,
+     UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         img_size: Tuple[int, int],
+         feature_size: int = 16,
+         hidden_size: int = 768,
+         mlp_dim: int = 3072,
+         num_heads: int = 12,
+         pos_embed: str = "perceptron",
+         norm_name: Union[Tuple, str] = "instance",
+         conv_block: bool = False,
+         res_block: bool = True,
+         dropout_rate: float = 0.0,
+         debug: bool = False,
+     ) -> None:
+         super().__init__()
+
+         if not (0 <= dropout_rate <= 1):
+             raise AssertionError("dropout_rate should be between 0 and 1.")
+         if hidden_size % num_heads != 0:
+             raise AssertionError("hidden size should be divisible by num_heads.")
+         if pos_embed not in ["conv", "perceptron"]:
+             raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
+
+         self.num_layers = 12
+         self.patch_size = (16, 16)
+         self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))
+         self.hidden_size = hidden_size
+         self.classification = False
+         self.debug = debug
+         self.vit = ViT(
+             in_channels=in_channels,
+             img_size=img_size,
+             patch_size=self.patch_size,
+             hidden_size=hidden_size,
+             mlp_dim=mlp_dim,
+             num_layers=self.num_layers,
+             num_heads=num_heads,
+             pos_embed=pos_embed,
+             classification=self.classification,
+             dropout_rate=dropout_rate,
+             spatial_dims=2,
+         )
+         self.encoder1 = UnetrBasicBlock(
+             spatial_dims=2,
+             in_channels=in_channels,
+             out_channels=feature_size,
+             kernel_size=3,
+             stride=1,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.encoder2 = UnetrPrUpBlock(
+             spatial_dims=2,
+             in_channels=hidden_size,
+             out_channels=feature_size * 2,
+             num_layer=2,
+             kernel_size=3,
+             stride=1,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             conv_block=conv_block,
+             res_block=res_block,
+         )
+         self.encoder3 = UnetrPrUpBlock(
+             spatial_dims=2,
+             in_channels=hidden_size,
+             out_channels=feature_size * 4,
+             num_layer=1,
+             kernel_size=3,
+             stride=1,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             conv_block=conv_block,
+             res_block=res_block,
+         )
+         self.encoder4 = UnetrPrUpBlock(
+             spatial_dims=2,
+             in_channels=hidden_size,
+             out_channels=feature_size * 8,
+             num_layer=0,
+             kernel_size=3,
+             stride=1,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             conv_block=conv_block,
+             res_block=res_block,
+         )
+         self.decoder5 = UnetrUpBlock(
+             spatial_dims=2,
+             in_channels=hidden_size,
+             out_channels=feature_size * 8,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder4 = UnetrUpBlock(
+             spatial_dims=2,
+             in_channels=feature_size * 8,
+             out_channels=feature_size * 4,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder3 = UnetrUpBlock(
+             spatial_dims=2,
+             in_channels=feature_size * 4,
+             out_channels=feature_size * 2,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.decoder2 = UnetrUpBlock(
+             spatial_dims=2,
+             in_channels=feature_size * 2,
+             out_channels=feature_size,
+             kernel_size=3,
+             upsample_kernel_size=2,
+             norm_name=norm_name,
+             res_block=res_block,
+         )
+         self.out = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels)  # type: ignore
+
+     def proj_feat(self, x, hidden_size, feat_size):  # x: (B, 256, 768)
+         x = x.view(x.size(0), feat_size[0], feat_size[1], hidden_size)  # (B, 16, 16, 768)
+         x = x.permute(0, 3, 1, 2).contiguous()  # (B, 768, 16, 16)
+         return x
+
+     def forward(self, x_in):
+         x, hidden_states_out = self.vit(x_in)  # x: (B, 256, 768); hidden_states_out: 12 tensors of (B, 256, 768)
+         enc1 = self.encoder1(x_in)  # (B, 16, 256, 256)
+         x2 = hidden_states_out[3]  # (B, 256, 768)
+         # self.proj_feat(x2, ...): (B, 768, 16, 16) -> enc2: (B, 32, 128, 128)
+         enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))  # hidden_size=768, feat_size=(16, 16)
+         x3 = hidden_states_out[6]  # (B, 256, 768)
+         enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))  # (B, 768, 16, 16) -> (B, 64, 64, 64)
+         x4 = hidden_states_out[9]  # (B, 256, 768)
+         enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))  # (B, 768, 16, 16) -> (B, 128, 32, 32)
+         dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)  # (B, 768, 16, 16)
+         dec3 = self.decoder5(dec4, enc4)  # up -> cat -> ResConv; (B, 128, 32, 32)
+         dec2 = self.decoder4(dec3, enc3)  # (B, 64, 64, 64)
+         dec1 = self.decoder3(dec2, enc2)  # (B, 32, 128, 128)
+         out = self.decoder2(dec1, enc1)  # (B, 16, 256, 256)
+         logits = self.out(out)
+
+         if self.debug:
+             return x, x2, x3, x4, hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
+         return logits
+
+
+ # model = UNETR2D(
+ #     in_channels=3,  # 3 channels: R, G, B
+ #     out_channels=3,
+ #     img_size=(256, 256),
+ #     feature_size=16,
+ #     hidden_size=768,
+ #     mlp_dim=3072,
+ #     num_heads=12,
+ #     pos_embed="perceptron",
+ #     norm_name="instance",
+ #     res_block=True,
+ #     dropout_rate=0.0,
+ #     debug=True,
+ # ).cuda()
+
+ # from torchinfo import summary
+ # batch_size = 1
+ # summary(model, input_size=(batch_size, 3, 256, 256))
+
+ # x = torch.rand((1, 3, 256, 256)).cuda()
+ # x, x2, x3, x4, hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits = model(x)
+ # print(logits.shape)  # torch.Size([1, 3, 256, 256])
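A CPU-friendly variant of the commented-out CUDA smoke test above, using the same constructor arguments as the 'unetr' branch in app.py; the final line mirrors the shape noted in the comments:

    import torch
    from unetr2d import UNETR2D

    # Build the 2D UNETR with the configuration used throughout this commit.
    model = UNETR2D(
        in_channels=3, out_channels=3, img_size=(256, 256),
        feature_size=16, hidden_size=768, mlp_dim=3072, num_heads=12,
        pos_embed="perceptron", norm_name="instance",
        res_block=True, dropout_rate=0.0,
    )
    model.eval()
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 256, 256))
    print(logits.shape)  # expected: torch.Size([1, 3, 256, 256])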