srijaydeshpande commited on
Commit
eafd2db
·
verified ·
1 Parent(s): e173465

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -284
app.py CHANGED
@@ -1,284 +1,285 @@
1
- import colorsys
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from PIL import Image
6
- from metrics import *
7
- import torchvision.transforms as T
8
- import gradio as gr
9
- import matplotlib.pyplot as plt
10
- import tempfile
11
-
12
- # image_path = r'F:\Datasets\BCSS_InstaDeep\splits\test\images\TCGA-A1-A0SK-DX1_xmin45749_ymin25055_MPP.png'
13
-
14
- class SPADE(nn.Module):
15
- def __init__(self, norm_nc, label_nc, norm):
16
- super().__init__()
17
-
18
-
19
- if norm == 'instance':
20
- self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
21
- elif norm == 'batch':
22
- self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
23
-
24
- # The dimension of the intermediate embedding space. Yes, hardcoded.
25
- nhidden = 128
26
- ks = 3
27
- pw = ks // 2
28
- self.mlp_shared = nn.Sequential(
29
- nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
30
- nn.ReLU()
31
- )
32
- self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
33
- self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
34
-
35
- def forward(self, x, segmap):
36
-
37
- # Part 1. generate parameter-free normalized activations
38
- normalized = self.param_free_norm(x)
39
-
40
- # Part 2. produce scaling and bias conditioned on semantic map
41
- segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
42
- actv = self.mlp_shared(segmap)
43
- gamma = self.mlp_gamma(actv)
44
- beta = self.mlp_beta(actv)
45
-
46
- # apply scale and bias
47
- out = normalized * (1 + gamma) + beta
48
-
49
- return out
50
-
51
- class SPADEResnetBlock(nn.Module):
52
- def __init__(self, fin, fout):
53
- super().__init__()
54
- # Attributes
55
- self.learned_shortcut = (fin != fout)
56
- fmiddle = min(fin, fout)
57
-
58
- # create conv layers
59
- self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
60
- self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
61
- if self.learned_shortcut:
62
- self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
63
-
64
- # define normalization layers
65
- self.norm_0 = SPADE(fin, 3, norm='instance')
66
- self.norm_1 = SPADE(fmiddle, 3, norm='instance')
67
- if self.learned_shortcut:
68
- self.norm_s = SPADE(fin, 3, norm='instance')
69
-
70
- def forward(self, x, seg):
71
- x_s = self.shortcut(x, seg)
72
-
73
- dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
74
- dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
75
-
76
- out = x_s + dx
77
-
78
- return out
79
-
80
- def shortcut(self, x, seg):
81
- if self.learned_shortcut:
82
- x_s = self.conv_s(self.norm_s(x, seg))
83
- else:
84
- x_s = x
85
- return x_s
86
-
87
- def actvn(self, x):
88
- return F.leaky_relu(x, 2e-1)
89
-
90
- class ResnetBlock(nn.Module):
91
-
92
- def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
93
- super(ResnetBlock, self).__init__()
94
- self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
95
-
96
- def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
97
- conv_block = []
98
- p = 0
99
- if padding_type == 'reflect':
100
- conv_block += [nn.ReflectionPad2d(1)]
101
- elif padding_type == 'replicate':
102
- conv_block += [nn.ReplicationPad2d(1)]
103
- elif padding_type == 'zero':
104
- p = 1
105
- else:
106
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
107
-
108
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
109
- norm_layer(dim),
110
- activation]
111
-
112
- if use_dropout:
113
- conv_block += [nn.Dropout(0.5)]
114
-
115
- p = 0
116
- if padding_type == 'reflect':
117
- conv_block += [nn.ReflectionPad2d(1)]
118
- elif padding_type == 'replicate':
119
- conv_block += [nn.ReplicationPad2d(1)]
120
- elif padding_type == 'zero':
121
- p = 1
122
- else:
123
- raise NotImplementedError('padding [%s] is not implemented' % padding_type)
124
- conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
125
- norm_layer(dim)]
126
-
127
- return nn.Sequential(*conv_block)
128
-
129
- def forward(self, x):
130
- out = x + self.conv_block(x)
131
- return out
132
-
133
- class SPADEResNet(torch.nn.Module):
134
-
135
- def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=5, norm_layer=nn.BatchNorm2d,
136
- padding_type='reflect'):
137
- assert (n_blocks >= 0)
138
- super(SPADEResNet, self).__init__()
139
- activation = nn.ReLU(True)
140
-
141
- downsampler = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
142
-
143
- ### downsample
144
- for i in range(n_downsampling):
145
- mult = 2 ** i
146
- downsampler += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
147
- norm_layer(ngf * mult * 2), activation]
148
- self.downsampler = nn.Sequential(*downsampler)
149
-
150
- ### resnet blocks
151
- mult = 2 ** n_downsampling
152
- self.resnetblocks1 = SPADEResnetBlock(ngf * mult, ngf * mult)
153
- self.resnetblocks2 = SPADEResnetBlock(ngf * mult, ngf * mult)
154
- self.resnetblocks3 = SPADEResnetBlock(ngf * mult, ngf * mult)
155
- self.resnetblocks4 = SPADEResnetBlock(ngf * mult, ngf * mult)
156
- self.resnetblocks5 = SPADEResnetBlock(ngf * mult, ngf * mult)
157
-
158
- ### upsample
159
- upsampler = []
160
- for i in range(n_downsampling):
161
- mult = 2 ** (n_downsampling - i)
162
- upsampler += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
163
- output_padding=1),
164
- norm_layer(int(ngf * mult / 2)), activation]
165
-
166
- upsampler += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
167
-
168
- self.upsampler = nn.Sequential(*upsampler)
169
-
170
- def forward(self, input):
171
- downsampled = self.downsampler(input)
172
- resnet1 = self.resnetblocks1(downsampled, input)
173
- resnet2 = self.resnetblocks1(resnet1, input)
174
- resnet3 = self.resnetblocks1(resnet2, input)
175
- resnet4 = self.resnetblocks1(resnet3, input)
176
- resnet5 = self.resnetblocks1(resnet4, input)
177
- upsampled = self.upsampler(resnet5)
178
- return upsampled
179
-
180
- def generate_colors(n):
181
- brightness = 0.7
182
- hsv = [(i / n, 1, brightness) for i in range(n)]
183
- colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
184
- colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),colors))
185
- return colors
186
-
187
- def generate_colored_image(labels):
188
- colors = generate_colors(6)
189
- w, h = labels.shape
190
- new_mk = np.empty([w, h, 3])
191
- for i in range(0,w):
192
- for j in range(0,h):
193
- new_mk[i][j] = colors[labels[i][j]]
194
- # new_mk = new_mk / 255.0
195
- new_mk = new_mk.astype(np.uint8)
196
- return Image.fromarray(new_mk)
197
-
198
- def predict_wsi(image):
199
- patch_size = 768
200
- stride = 700 # stride is kept relatively lower than the tile size so as to allow some overlap while constructing bigger regions
201
- generator_output_size = patch_size
202
- num_classes=5
203
- pred_labels = torch.zeros(1, num_classes+1, image.shape[2], image.shape[3]).cuda()
204
- counter_tensor = torch.zeros(1, 1, image.shape[2], image.shape[3]).cuda()
205
- for i in range(0, image.shape[2] - patch_size + 1, stride):
206
- for j in range(0, image.shape[3] - patch_size + 1, stride):
207
- i_lowered = min(i, image.shape[2] - patch_size)
208
- j_lowered = min(j, image.shape[3] - patch_size)
209
- patch = image[:, :, i_lowered:i_lowered + patch_size, j_lowered:j_lowered + patch_size]
210
- pred_labels_patch = model(patch.float())
211
- update_region_i = i_lowered + (patch_size - generator_output_size) // 2
212
- update_region_j = j_lowered + (patch_size - generator_output_size) // 2
213
- pred_labels[:, :, update_region_i:update_region_i + generator_output_size,
214
- update_region_j:update_region_j + generator_output_size] += pred_labels_patch
215
- counter_tensor[:, :, update_region_i:update_region_i + generator_output_size,
216
- update_region_j:update_region_j + generator_output_size] += 1
217
- pred_labels /= counter_tensor
218
- return pred_labels
219
-
220
- def segment_image(image):
221
- # img = Image.open(image_path)
222
- img = image
223
- img = np.asarray(img)
224
- if (np.max(img) > 100):
225
- img = img / 255.0
226
- transform = T.Compose([T.ToTensor()])
227
- image = transform(img)
228
- image = image[None, :]
229
- with torch.no_grad():
230
- pred_labels = predict_wsi(image.float())
231
- pred_labels = F.softmax(pred_labels, dim=1)
232
- pred_labels_probs = pred_labels.cpu().numpy()
233
- pred_labels = np.argmax(pred_labels_probs, axis=1)
234
- pred_labels = pred_labels[0]
235
- image = generate_colored_image(pred_labels)
236
- class_labels = ['tumor', 'stroma', 'inflammatory', 'necrosis', 'others']
237
- pixels_counts = []
238
- total=0
239
- print(np.unique(pred_labels))
240
- for i in range(1,len(class_labels)+1):
241
- current_count=np.sum(pred_labels == i)
242
- pixels_counts.append(current_count)
243
- total+=current_count
244
- pixels_counts = [(value / total) * 100 for value in pixels_counts]
245
- print(pixels_counts)
246
- plt.figure(figsize=(10, 6))
247
- bar_width = 0.15
248
- plt.bar(class_labels, pixels_counts, color='blue', width=bar_width)
249
- plt.xticks(rotation=45, ha='right')
250
- plt.xlabel('Tissue types', fontsize=17)
251
- plt.ylabel('Class Percentage', fontsize=17)
252
- plt.title('Classes distribution', fontsize=18)
253
- plt.xticks(fontsize=16)
254
- plt.yticks(fontsize=16)
255
- plt.tight_layout()
256
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmpfile:
257
- plt.savefig(tmpfile.name)
258
- temp_filename = tmpfile.name
259
- stats = Image.open(temp_filename)
260
-
261
- legend = Image.open('legend.png')
262
-
263
- # new_width = max(stats.width, legend.width)
264
- # new_height = stats.height + legend.height
265
- # new_image = Image.new("RGB", (new_width, new_height), (255, 255, 255))
266
- # new_image.paste(stats, (0, 0))
267
- # new_image.paste(legend, (image.height,0))
268
-
269
- return image, legend, stats
270
-
271
- model_path = './models/spaderesnet/spaderesnet16.pt'
272
- model = SPADEResNet(input_nc=3, output_nc=6)
273
- model = nn.DataParallel(model)
274
- model = model.cuda()
275
- model.load_state_dict(torch.load(model_path), strict=True)
276
-
277
- demo = gr.Interface(
278
- segment_image,
279
- inputs=gr.Image(),
280
- outputs=["image", "image", "image"],
281
- title="Breast Cancer Semantic Segmentation"
282
- )
283
-
284
- demo.launch()
 
 
1
+ import colorsys
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from PIL import Image
6
+ from metrics import *
7
+ import torchvision.transforms as T
8
+ import gradio as gr
9
+ import matplotlib.pyplot as plt
10
+ import tempfile
11
+
12
+ from huggingface_hub import snapshot_download
13
+
14
+ from huggingface_hub import login
15
+ login(token = os.getenv('HF_TOKEN'))
16
+
17
+ model_dir = snapshot_download(
18
+ repo_id="srijaydeshpande/spadesegresnet"
19
+ )
20
+
21
+ class SPADE(nn.Module):
22
+ def __init__(self, norm_nc, label_nc, norm):
23
+ super().__init__()
24
+
25
+
26
+ if norm == 'instance':
27
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
28
+ elif norm == 'batch':
29
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
30
+
31
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
32
+ nhidden = 128
33
+ ks = 3
34
+ pw = ks // 2
35
+ self.mlp_shared = nn.Sequential(
36
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
37
+ nn.ReLU()
38
+ )
39
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
40
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
41
+
42
+ def forward(self, x, segmap):
43
+
44
+ # Part 1. generate parameter-free normalized activations
45
+ normalized = self.param_free_norm(x)
46
+
47
+ # Part 2. produce scaling and bias conditioned on semantic map
48
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
49
+ actv = self.mlp_shared(segmap)
50
+ gamma = self.mlp_gamma(actv)
51
+ beta = self.mlp_beta(actv)
52
+
53
+ # apply scale and bias
54
+ out = normalized * (1 + gamma) + beta
55
+
56
+ return out
57
+
58
+ class SPADEResnetBlock(nn.Module):
59
+ def __init__(self, fin, fout):
60
+ super().__init__()
61
+ # Attributes
62
+ self.learned_shortcut = (fin != fout)
63
+ fmiddle = min(fin, fout)
64
+
65
+ # create conv layers
66
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
67
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
68
+ if self.learned_shortcut:
69
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
70
+
71
+ # define normalization layers
72
+ self.norm_0 = SPADE(fin, 3, norm='instance')
73
+ self.norm_1 = SPADE(fmiddle, 3, norm='instance')
74
+ if self.learned_shortcut:
75
+ self.norm_s = SPADE(fin, 3, norm='instance')
76
+
77
+ def forward(self, x, seg):
78
+ x_s = self.shortcut(x, seg)
79
+
80
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
81
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
82
+
83
+ out = x_s + dx
84
+
85
+ return out
86
+
87
+ def shortcut(self, x, seg):
88
+ if self.learned_shortcut:
89
+ x_s = self.conv_s(self.norm_s(x, seg))
90
+ else:
91
+ x_s = x
92
+ return x_s
93
+
94
+ def actvn(self, x):
95
+ return F.leaky_relu(x, 2e-1)
96
+
97
+ class ResnetBlock(nn.Module):
98
+
99
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
100
+ super(ResnetBlock, self).__init__()
101
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
102
+
103
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
104
+ conv_block = []
105
+ p = 0
106
+ if padding_type == 'reflect':
107
+ conv_block += [nn.ReflectionPad2d(1)]
108
+ elif padding_type == 'replicate':
109
+ conv_block += [nn.ReplicationPad2d(1)]
110
+ elif padding_type == 'zero':
111
+ p = 1
112
+ else:
113
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
114
+
115
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
116
+ norm_layer(dim),
117
+ activation]
118
+
119
+ if use_dropout:
120
+ conv_block += [nn.Dropout(0.5)]
121
+
122
+ p = 0
123
+ if padding_type == 'reflect':
124
+ conv_block += [nn.ReflectionPad2d(1)]
125
+ elif padding_type == 'replicate':
126
+ conv_block += [nn.ReplicationPad2d(1)]
127
+ elif padding_type == 'zero':
128
+ p = 1
129
+ else:
130
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
131
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
132
+ norm_layer(dim)]
133
+
134
+ return nn.Sequential(*conv_block)
135
+
136
+ def forward(self, x):
137
+ out = x + self.conv_block(x)
138
+ return out
139
+
140
+ class SPADEResNet(torch.nn.Module):
141
+
142
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=5, norm_layer=nn.BatchNorm2d,
143
+ padding_type='reflect'):
144
+ assert (n_blocks >= 0)
145
+ super(SPADEResNet, self).__init__()
146
+ activation = nn.ReLU(True)
147
+
148
+ downsampler = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
149
+
150
+ ### downsample
151
+ for i in range(n_downsampling):
152
+ mult = 2 ** i
153
+ downsampler += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
154
+ norm_layer(ngf * mult * 2), activation]
155
+ self.downsampler = nn.Sequential(*downsampler)
156
+
157
+ ### resnet blocks
158
+ mult = 2 ** n_downsampling
159
+ self.resnetblocks1 = SPADEResnetBlock(ngf * mult, ngf * mult)
160
+ self.resnetblocks2 = SPADEResnetBlock(ngf * mult, ngf * mult)
161
+ self.resnetblocks3 = SPADEResnetBlock(ngf * mult, ngf * mult)
162
+ self.resnetblocks4 = SPADEResnetBlock(ngf * mult, ngf * mult)
163
+ self.resnetblocks5 = SPADEResnetBlock(ngf * mult, ngf * mult)
164
+
165
+ ### upsample
166
+ upsampler = []
167
+ for i in range(n_downsampling):
168
+ mult = 2 ** (n_downsampling - i)
169
+ upsampler += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
170
+ output_padding=1),
171
+ norm_layer(int(ngf * mult / 2)), activation]
172
+
173
+ upsampler += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
174
+
175
+ self.upsampler = nn.Sequential(*upsampler)
176
+
177
+ def forward(self, input):
178
+ downsampled = self.downsampler(input)
179
+ resnet1 = self.resnetblocks1(downsampled, input)
180
+ resnet2 = self.resnetblocks1(resnet1, input)
181
+ resnet3 = self.resnetblocks1(resnet2, input)
182
+ resnet4 = self.resnetblocks1(resnet3, input)
183
+ resnet5 = self.resnetblocks1(resnet4, input)
184
+ upsampled = self.upsampler(resnet5)
185
+ return upsampled
186
+
187
+ def generate_colors(n):
188
+ brightness = 0.7
189
+ hsv = [(i / n, 1, brightness) for i in range(n)]
190
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
191
+ colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),colors))
192
+ return colors
193
+
194
+ def generate_colored_image(labels):
195
+ colors = generate_colors(6)
196
+ w, h = labels.shape
197
+ new_mk = np.empty([w, h, 3])
198
+ for i in range(0,w):
199
+ for j in range(0,h):
200
+ new_mk[i][j] = colors[labels[i][j]]
201
+ # new_mk = new_mk / 255.0
202
+ new_mk = new_mk.astype(np.uint8)
203
+ return Image.fromarray(new_mk)
204
+
205
+ def predict_wsi(image):
206
+ patch_size = 768
207
+ stride = 700 # stride is kept relatively lower than the tile size so as to allow some overlap while constructing bigger regions
208
+ generator_output_size = patch_size
209
+ num_classes=5
210
+ pred_labels = torch.zeros(1, num_classes+1, image.shape[2], image.shape[3]).cuda()
211
+ counter_tensor = torch.zeros(1, 1, image.shape[2], image.shape[3]).cuda()
212
+ for i in range(0, image.shape[2] - patch_size + 1, stride):
213
+ for j in range(0, image.shape[3] - patch_size + 1, stride):
214
+ i_lowered = min(i, image.shape[2] - patch_size)
215
+ j_lowered = min(j, image.shape[3] - patch_size)
216
+ patch = image[:, :, i_lowered:i_lowered + patch_size, j_lowered:j_lowered + patch_size]
217
+ pred_labels_patch = model(patch.float())
218
+ update_region_i = i_lowered + (patch_size - generator_output_size) // 2
219
+ update_region_j = j_lowered + (patch_size - generator_output_size) // 2
220
+ pred_labels[:, :, update_region_i:update_region_i + generator_output_size,
221
+ update_region_j:update_region_j + generator_output_size] += pred_labels_patch
222
+ counter_tensor[:, :, update_region_i:update_region_i + generator_output_size,
223
+ update_region_j:update_region_j + generator_output_size] += 1
224
+ pred_labels /= counter_tensor
225
+ return pred_labels
226
+
227
+ def segment_image(image):
228
+ # img = Image.open(image_path)
229
+ img = image
230
+ img = np.asarray(img)
231
+ if (np.max(img) > 100):
232
+ img = img / 255.0
233
+ transform = T.Compose([T.ToTensor()])
234
+ image = transform(img)
235
+ image = image[None, :]
236
+ with torch.no_grad():
237
+ pred_labels = predict_wsi(image.float())
238
+ pred_labels = F.softmax(pred_labels, dim=1)
239
+ pred_labels_probs = pred_labels.cpu().numpy()
240
+ pred_labels = np.argmax(pred_labels_probs, axis=1)
241
+ pred_labels = pred_labels[0]
242
+ image = generate_colored_image(pred_labels)
243
+ class_labels = ['tumor', 'stroma', 'inflammatory', 'necrosis', 'others']
244
+ pixels_counts = []
245
+ total=0
246
+ print(np.unique(pred_labels))
247
+ for i in range(1,len(class_labels)+1):
248
+ current_count=np.sum(pred_labels == i)
249
+ pixels_counts.append(current_count)
250
+ total+=current_count
251
+ pixels_counts = [(value / total) * 100 for value in pixels_counts]
252
+ print(pixels_counts)
253
+ plt.figure(figsize=(10, 6))
254
+ bar_width = 0.15
255
+ plt.bar(class_labels, pixels_counts, color='blue', width=bar_width)
256
+ plt.xticks(rotation=45, ha='right')
257
+ plt.xlabel('Tissue types', fontsize=17)
258
+ plt.ylabel('Class Percentage', fontsize=17)
259
+ plt.title('Classes distribution', fontsize=18)
260
+ plt.xticks(fontsize=16)
261
+ plt.yticks(fontsize=16)
262
+ plt.tight_layout()
263
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmpfile:
264
+ plt.savefig(tmpfile.name)
265
+ temp_filename = tmpfile.name
266
+ stats = Image.open(temp_filename)
267
+
268
+ legend = Image.open('legend.png')
269
+
270
+ return image, legend, stats
271
+
272
+ model_path = os.path.join(model_dir, 'spaderesnet.pt')
273
+ model = SPADEResNet(input_nc=3, output_nc=6)
274
+ model = nn.DataParallel(model)
275
+ model = model.cuda()
276
+ model.load_state_dict(torch.load(model_path), strict=True)
277
+
278
+ demo = gr.Interface(
279
+ segment_image,
280
+ inputs=gr.Image(),
281
+ outputs=["image", "image", "image"],
282
+ title="Breast Cancer Semantic Segmentation"
283
+ )
284
+
285
+ demo.launch()