srijaydeshpande committed on
Commit 629b17d · verified · 1 Parent(s): 8a06827

Upload app.py

Files changed (1): app.py (+284 -0)
app.py ADDED
@@ -0,0 +1,284 @@
import colorsys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from metrics import *
import torchvision.transforms as T
import gradio as gr
import matplotlib.pyplot as plt
import tempfile

class SPADE(nn.Module):
    def __init__(self, norm_nc, label_nc, norm):
        super().__init__()

        if norm == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif norm == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        ks = 3
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out

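# A minimal sketch of the SPADE contract (illustrative only; the shapes below are
# assumptions, not taken from this repo): the block normalizes x, then rescales it
# with a per-pixel gamma/beta predicted from the conditioning map, which is
# resized internally, so the output keeps x's shape.
#
#   spade = SPADE(norm_nc=64, label_nc=3, norm='instance')
#   x = torch.randn(1, 64, 32, 32)      # feature map
#   seg = torch.randn(1, 3, 256, 256)   # conditioning image
#   assert spade(x, seg).shape == x.shape
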
class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # define normalization layers
        self.norm_0 = SPADE(fin, 3, norm='instance')
        self.norm_1 = SPADE(fmiddle, 3, norm='instance')
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, 3, norm='instance')

    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        out = x_s + dx

        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)

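# Note: the SPADE layers above are built with label_nc=3, i.e. the residual
# blocks are conditioned on the raw RGB image itself rather than a one-hot
# semantic map; SPADEResNet.forward below passes the network input straight
# through as 'seg'.
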
class ResnetBlock(nn.Module):

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

class SPADEResNet(torch.nn.Module):

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=5, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(SPADEResNet, self).__init__()
        activation = nn.ReLU(True)

        downsampler = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            downsampler += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                            norm_layer(ngf * mult * 2), activation]
        self.downsampler = nn.Sequential(*downsampler)

        ### resnet blocks
        mult = 2 ** n_downsampling
        self.resnetblocks1 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks2 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks3 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks4 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks5 = SPADEResnetBlock(ngf * mult, ngf * mult)

        ### upsample
        upsampler = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            upsampler += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                             output_padding=1),
                          norm_layer(int(ngf * mult / 2)), activation]

        upsampler += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

        self.upsampler = nn.Sequential(*upsampler)

    def forward(self, input):
        downsampled = self.downsampler(input)
        # each SPADE residual block is conditioned on the original RGB input
        resnet1 = self.resnetblocks1(downsampled, input)
        resnet2 = self.resnetblocks2(resnet1, input)
        resnet3 = self.resnetblocks3(resnet2, input)
        resnet4 = self.resnetblocks4(resnet3, input)
        resnet5 = self.resnetblocks5(resnet4, input)
        upsampled = self.upsampler(resnet5)
        return upsampled

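# Rough shape trace through SPADEResNet for the default config (ngf=64,
# n_downsampling=3); the 768x768 input size is an illustrative assumption:
#   input        (1, 3, 768, 768)
#   downsampler  (1, 512, 96, 96)    # 64 * 2**3 channels at 1/8 resolution
#   5 SPADE residual blocks          # shape-preserving, conditioned on the RGB input
#   upsampler    (1, 6, 768, 768)    # one logit map per class (5 tissue classes + background)
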
def generate_colors(n):
    brightness = 0.7
    hsv = [(i / n, 1, brightness) for i in range(n)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
    return colors

def generate_colored_image(labels):
    colors = generate_colors(6)
    # map each per-pixel class label to its RGB colour in one vectorized lookup
    palette = np.array(colors, dtype=np.uint8)
    new_mk = palette[labels]
    return Image.fromarray(new_mk)

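# Example of the palette lookup (illustrative values only):
#   labels = np.array([[0, 1], [5, 2]])           # per-pixel class ids in 0..5
#   generate_colored_image(labels).size == (2, 2) # PIL size is (width, height)
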
def predict_wsi(image):
    # assumes the input is at least patch_size in both spatial dimensions
    patch_size = 768
    stride = 700  # kept below the patch size so neighbouring patches overlap
    generator_output_size = patch_size
    num_classes = 5
    pred_labels = torch.zeros(1, num_classes + 1, image.shape[2], image.shape[3]).to(device)
    counter_tensor = torch.zeros(1, 1, image.shape[2], image.shape[3]).to(device)
    # the ranges overshoot by up to one stride; the clamped i_lowered/j_lowered
    # starts pull the last window back so the bottom/right edges are still covered
    for i in range(0, image.shape[2] - patch_size + stride, stride):
        for j in range(0, image.shape[3] - patch_size + stride, stride):
            i_lowered = min(i, image.shape[2] - patch_size)
            j_lowered = min(j, image.shape[3] - patch_size)
            patch = image[:, :, i_lowered:i_lowered + patch_size, j_lowered:j_lowered + patch_size]
            pred_labels_patch = model(patch.float())
            update_region_i = i_lowered + (patch_size - generator_output_size) // 2
            update_region_j = j_lowered + (patch_size - generator_output_size) // 2
            pred_labels[:, :, update_region_i:update_region_i + generator_output_size,
                        update_region_j:update_region_j + generator_output_size] += pred_labels_patch
            counter_tensor[:, :, update_region_i:update_region_i + generator_output_size,
                           update_region_j:update_region_j + generator_output_size] += 1
    # average the logits where overlapping patches contributed more than once
    pred_labels /= counter_tensor
    return pred_labels

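# Edge handling in one dimension (illustrative numbers, not from the repo):
# with image.shape[2] = 1000, patch_size = 768 and stride = 700, the loop visits
# i = 0 and i = 700; the latter is clamped to 232, so rows 768-999 are still
# covered, and rows 232-767 are averaged over the two overlapping patches.
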
def segment_image(image):
    img = np.asarray(image)
    # normalise 0-255 inputs to the 0-1 range the model expects
    if np.max(img) > 100:
        img = img / 255.0
    transform = T.Compose([T.ToTensor()])
    image = transform(img)
    image = image[None, :]
    with torch.no_grad():
        pred_labels = predict_wsi(image.float())
        pred_labels = F.softmax(pred_labels, dim=1)
    pred_labels_probs = pred_labels.cpu().numpy()
    pred_labels = np.argmax(pred_labels_probs, axis=1)
    pred_labels = pred_labels[0]
    image = generate_colored_image(pred_labels)
    class_labels = ['tumor', 'stroma', 'inflammatory', 'necrosis', 'others']
    pixels_counts = []
    total = 0
    # class 0 is background; tissue classes are labelled 1..5
    for i in range(1, len(class_labels) + 1):
        current_count = np.sum(pred_labels == i)
        pixels_counts.append(current_count)
        total += current_count
    pixels_counts = [(value / total) * 100 for value in pixels_counts]
    plt.figure(figsize=(10, 6))
    bar_width = 0.15
    plt.bar(class_labels, pixels_counts, color='blue', width=bar_width)
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Tissue types', fontsize=17)
    plt.ylabel('Class Percentage', fontsize=17)
    plt.title('Classes distribution', fontsize=18)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.tight_layout()
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmpfile:
        plt.savefig(tmpfile.name)
        temp_filename = tmpfile.name
    plt.close()  # free the figure so repeated requests do not leak memory
    stats = Image.open(temp_filename)

    legend = Image.open('legend.png')

    return image, legend, stats

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = './models/spaderesnet/spaderesnet16.pt'
model = SPADEResNet(input_nc=3, output_nc=6)
# the checkpoint was presumably saved from a DataParallel-wrapped model (keys
# carry a 'module.' prefix), so wrapping here lets strict loading match them
model = nn.DataParallel(model)
model = model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()  # inference only: fixes norm-layer behaviour and disables dropout

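# Quick end-to-end smoke test (illustrative only; 'sample_patch.png' is a
# placeholder filename, and the image must be at least 768x768):
#   seg_map, legend_img, stats_img = segment_image(Image.open('sample_patch.png'))
#   seg_map.size  # equals the input's (width, height)
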
demo = gr.Interface(
    segment_image,
    inputs=gr.Image(),
    outputs=["image", "image", "image"],  # segmentation map, legend, class-distribution chart
    title="Breast Cancer Semantic Segmentation"
)

demo.launch()