Prathush21 commited on
Commit
6c97fab
·
verified ·
1 Parent(s): 666c899

Upload lrp_pipeline_2.py

Browse files
Files changed (1) hide show
  1. lrp_pipeline_2.py +417 -0
lrp_pipeline_2.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torchvision
6
+ import os
7
+ import copy
8
+ from sklearn.mixture import GaussianMixture as GMM
9
+ from sklearn.cluster import KMeans
10
+ from simple_lama_inpainting import SimpleLama
11
+ from PIL import Image
12
+ from matplotlib.colors import ListedColormap
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib
15
+ import csv
16
+
17
+ matplotlib.use("Agg")
18
+
19
+ import base64
20
+
21
+ from utils import (
22
+ select_sample_images,
23
+ create_cell_descriptors_table,
24
+ calculate_cell_descriptors,
25
+ )
26
+
27
+ preprocessed_folder = "uploads/"
28
+ intermediate_folder = "heatmaps/"
29
+ segmentation_folder = "segmentations/"
30
+ tables_folder = "tables/"
31
+ cell_descriptors_path = "cell_descriptors/cell_descriptors.csv"
32
+ imgclasses = {0: "abnormal", 1: "normal"}
33
+
34
+
35
+ def toconv(layers):
36
+ newlayers = []
37
+ for i, layer in enumerate(layers):
38
+ if isinstance(layer, nn.Linear):
39
+ newlayer = None
40
+ if i == 0:
41
+ m, n = 512, layer.weight.shape[0]
42
+ newlayer = nn.Conv2d(m, n, 4)
43
+ newlayer.weight = nn.Parameter(layer.weight.reshape(n, m, 4, 4))
44
+ else:
45
+ m, n = layer.weight.shape[1], layer.weight.shape[0]
46
+ newlayer = nn.Conv2d(m, n, 1)
47
+ newlayer.weight = nn.Parameter(layer.weight.reshape(n, m, 1, 1))
48
+ newlayer.bias = nn.Parameter(layer.bias)
49
+ newlayers += [newlayer]
50
+ else:
51
+ newlayers += [layer]
52
+ return newlayers
53
+
54
+
55
+ def newlayer(layer, g):
56
+ layer = copy.deepcopy(layer)
57
+ try:
58
+ layer.weight = nn.Parameter(g(layer.weight))
59
+ except AttributeError:
60
+ pass
61
+ try:
62
+ layer.bias = nn.Parameter(g(layer.bias))
63
+ except AttributeError:
64
+ pass
65
+ return layer
66
+
67
+
68
+ def heatmap(R, sx, sy, intermediate_path):
69
+ b = 10 * ((np.abs(R) ** 3.0).mean() ** (1.0 / 3))
70
+ my_cmap = plt.cm.seismic(np.arange(plt.cm.seismic.N))
71
+ my_cmap[:, 0:3] *= 0.85
72
+ my_cmap = ListedColormap(my_cmap)
73
+ plt.figure(figsize=(sx, sy))
74
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
75
+ plt.axis("off")
76
+ plt.imshow(R, cmap=my_cmap, vmin=-b, vmax=b, interpolation="nearest")
77
+ # plt.show()
78
+ plt.savefig(intermediate_path, bbox_inches="tight", pad_inches=0)
79
+ plt.close()
80
+
81
+
82
+ def get_LRP_heatmap(image, L, layers, imgclasses, intermediate_path):
83
+ img = np.array(image)[..., ::-1] / 255.0
84
+ mean = torch.FloatTensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) # torch.cuda
85
+ std = torch.FloatTensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) # torch.cuda
86
+ X = (torch.FloatTensor(img[np.newaxis].transpose([0, 3, 1, 2]) * 1) - mean) / std
87
+
88
+ A = [X] + [None] * L
89
+ for l in range(L):
90
+ A[l + 1] = layers[l].forward(A[l])
91
+
92
+ scores = np.array(A[-1].cpu().data.view(-1))
93
+ ind = np.argsort(-scores)
94
+ for i in ind[:2]:
95
+ print("%20s (%3d): %6.3f" % (imgclasses[i], i, scores[i]))
96
+
97
+ T = torch.FloatTensor(
98
+ (1.0 * (np.arange(2) == ind[0]).reshape([1, 2, 1, 1]))
99
+ ) # SET FOR THE HIGHEST SCORE CLASS
100
+ R = [None] * L + [(A[-1] * T).data]
101
+ for l in range(1, L)[::-1]:
102
+ A[l] = (A[l].data).requires_grad_(True)
103
+ if isinstance(layers[l], torch.nn.MaxPool2d):
104
+ layers[l] = torch.nn.AvgPool2d(2)
105
+ if isinstance(layers[l], torch.nn.Conv2d) or isinstance(
106
+ layers[l], torch.nn.AvgPool2d
107
+ ):
108
+ rho = lambda p: p + 0.25 * p.clamp(min=0)
109
+ incr = lambda z: z + 1e-9 # USE ONLY THE GAMMA RULE FOR ALL LAYERS
110
+
111
+ z = incr(newlayer(layers[l], rho).forward(A[l])) # step 1
112
+ # adding epsilon
113
+ epsilon = 1e-9
114
+ z_nonzero = torch.where(z == 0, torch.tensor(epsilon, device=z.device), z)
115
+ s = (R[l + 1] / z_nonzero).data
116
+ # s = (R[l+1]/z).data # step 2
117
+ (z * s).sum().backward()
118
+ c = A[l].grad # step 3
119
+ R[l] = (A[l] * c).data # step 4
120
+ else:
121
+ R[l] = R[l + 1]
122
+
123
+ A[0] = (A[0].data).requires_grad_(True)
124
+ lb = (A[0].data * 0 + (0 - mean) / std).requires_grad_(True)
125
+ hb = (A[0].data * 0 + (1 - mean) / std).requires_grad_(True)
126
+
127
+ z = layers[0].forward(A[0]) + 1e-9 # step 1 (a)
128
+ z -= newlayer(layers[0], lambda p: p.clamp(min=0)).forward(lb) # step 1 (b)
129
+ z -= newlayer(layers[0], lambda p: p.clamp(max=0)).forward(hb) # step 1 (c)
130
+
131
+ # adding epsilon
132
+ epsilon = 1e-9
133
+ z_nonzero = torch.where(z == 0, torch.tensor(epsilon, device=z.device), z)
134
+ s = (R[1] / z_nonzero).data # step 2
135
+
136
+ (z * s).sum().backward()
137
+ c, cp, cm = A[0].grad, lb.grad, hb.grad # step 3
138
+ R[0] = (A[0] * c + lb * cp + hb * cm).data # step 4
139
+ heatmap(
140
+ np.array(R[0][0].cpu()).sum(axis=0), 2, 2, intermediate_path
141
+ ) # HEATMAPPING TO SEE LRP MAPS WITH NEW RULE
142
+ return R[0][0].cpu()
143
+
144
+
145
+ def get_nucleus_mask_for_graphcut(R):
146
+ res = np.array(R).sum(axis=0)
147
+ # Reshape the data to a 1D array
148
+ data_1d = res.flatten().reshape(-1, 1)
149
+ n_clusters = 2
150
+ kmeans = KMeans(n_clusters=n_clusters, random_state=0)
151
+ # kmeans.fit(data_1d)
152
+ kmeans.fit(data_1d)
153
+ # Step 4: Assign data points to clusters
154
+ cluster_assignments = kmeans.labels_
155
+ # Step 5: Reshape cluster assignments into a 2D binary matrix
156
+ binary_matrix = cluster_assignments.reshape(128, 128)
157
+ # Now, binary_matrix contains 0s and 1s, separating the data into two classes using K-Means clustering
158
+ rel_grouping = np.zeros((128, 128, 3), dtype=np.uint8)
159
+ rel_grouping[binary_matrix == 1] = [255, 0, 0] # Main object (Blue)
160
+ rel_grouping[binary_matrix == 2] = [128, 0, 0] # Second label (Dark Blue)
161
+ rel_grouping[binary_matrix == 0] = [0, 0, 255] # Background (Red)
162
+ return rel_grouping
163
+
164
+
165
+ def segment_nucleus(image, rel_grouping): # clustered = rel_grouping
166
+
167
+ # GET THE BOUNDING BOX FROM CLUSTERED
168
+ blue_pixels = np.sum(np.all(rel_grouping == [255, 0, 0], axis=-1))
169
+ red_pixels = np.sum(np.all(rel_grouping == [0, 0, 255], axis=-1))
170
+ if red_pixels > blue_pixels:
171
+ color = np.array([255, 0, 0])
172
+ else:
173
+ color = np.array([0, 0, 255])
174
+ mask = cv2.inRange(rel_grouping, color, color)
175
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
176
+ contour_areas = []
177
+ for contour in contours:
178
+ x, y, w, h = cv2.boundingRect(contour)
179
+ contour_areas.append(cv2.contourArea(contour))
180
+ contour_areas.sort()
181
+ contour_areas = np.array(contour_areas)
182
+ quartile_50 = np.percentile(contour_areas, 50)
183
+ selected_contours = [
184
+ contour for contour in contours if cv2.contourArea(contour) >= quartile_50
185
+ ]
186
+ x, y, w, h = cv2.boundingRect(np.concatenate(selected_contours))
187
+
188
+ # APPLY GRABCUT
189
+ fgModel = np.zeros((1, 65), dtype="float")
190
+ bgModel = np.zeros((1, 65), dtype="float")
191
+ mask = np.zeros(image.shape[:2], np.uint8)
192
+ rect = (x, y, x + w, y + h)
193
+
194
+ # IF BOUNDING BOX IS THE WHOLE IMAGE, THEN BOUNDING BOX METHOD WONT'T WORK -> SO USE INIT WITH MASK METHOD ITSELF
195
+ if (x, y, x + w, y + h) == (0, 0, 128, 128):
196
+
197
+ if (
198
+ red_pixels > blue_pixels
199
+ ): # red is the dominant color and thus the background
200
+ mask[(rel_grouping == [255, 0, 0]).all(axis=2)] = (
201
+ cv2.GC_PR_FGD
202
+ ) # Probable Foreground
203
+ mask[(rel_grouping == [0, 0, 255]).all(axis=2)] = (
204
+ cv2.GC_PR_BGD
205
+ ) # Probable Background
206
+ else: # blue is the dominant color and thus the background
207
+ mask[(rel_grouping == [0, 0, 255]).all(axis=2)] = (
208
+ cv2.GC_PR_FGD
209
+ ) # Probable Foreground
210
+ mask[(rel_grouping == [255, 0, 0]).all(axis=2)] = (
211
+ cv2.GC_PR_BGD
212
+ ) # Probable Background
213
+
214
+ (mask, bgModel, fgModel) = cv2.grabCut(
215
+ image,
216
+ mask,
217
+ rect,
218
+ bgModel,
219
+ fgModel,
220
+ iterCount=10,
221
+ mode=cv2.GC_INIT_WITH_MASK,
222
+ )
223
+
224
+ # ELSE PASS THE BOUNDING BOX FOR GRABCUT
225
+ else:
226
+ (mask, bgModel, fgModel) = cv2.grabCut(
227
+ image,
228
+ mask,
229
+ rect,
230
+ bgModel,
231
+ fgModel,
232
+ iterCount=10,
233
+ mode=cv2.GC_INIT_WITH_RECT,
234
+ )
235
+
236
+ # FORM THE COLORED SEGMENTATION MASK
237
+ clean_binary_mask = np.where(
238
+ (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 1, 0
239
+ ).astype("uint8")
240
+ nucleus_segment = np.zeros((128, 128, 3), dtype=np.uint8)
241
+ nucleus_segment[clean_binary_mask == 1] = [255, 0, 0] # Main object (Blue)
242
+ nucleus_segment[clean_binary_mask == 0] = [0, 0, 255] # Background (Red)
243
+ return nucleus_segment, clean_binary_mask
244
+
245
+
246
+ def remove_nucleus(image1, blue_mask1, simple_lama): # image, blue_mask, x, y
247
+ # expand the nucleus mask
248
+ # image1 = cv2.resize(image, (128,128))
249
+ # blue_mask1 = cv2.resize(blue_mask, (128,128))
250
+ kernel = np.ones((5, 5), np.uint8) # Adjust the kernel size as needed
251
+ expandedmask = cv2.dilate(blue_mask1, kernel, iterations=1)
252
+ image_pil = Image.fromarray(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
253
+ mask_pil = Image.fromarray(expandedmask)
254
+ result = simple_lama(image_pil, mask_pil)
255
+ result_cv2 = np.array(result)
256
+ result_cv2 = cv2.cvtColor(result_cv2, cv2.COLOR_RGB2BGR)
257
+ # result_cv2 = cv2.resize(result_cv2, (x,y))
258
+ return expandedmask, result_cv2
259
+
260
+
261
+ def get_final_mask(nucleus_removed_img, blue_mask, expanded_mask):
262
+ # apply graphcut - init with rectangle (not mask approximation mask)
263
+ fgModel = np.zeros((1, 65), dtype="float")
264
+ bgModel = np.zeros((1, 65), dtype="float")
265
+
266
+ rect = (1, 1, nucleus_removed_img.shape[1], nucleus_removed_img.shape[0])
267
+
268
+ (mask, bgModel, fgModel) = cv2.grabCut(
269
+ nucleus_removed_img,
270
+ expanded_mask,
271
+ rect,
272
+ bgModel,
273
+ fgModel,
274
+ iterCount=20,
275
+ mode=cv2.GC_INIT_WITH_RECT,
276
+ )
277
+
278
+ clean_binary_mask = np.where(
279
+ (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 1, 0
280
+ ).astype("uint8")
281
+ colored_segmentation_mask = np.zeros((128, 128, 3), dtype=np.uint8)
282
+ colored_segmentation_mask[clean_binary_mask == 1] = [
283
+ 128,
284
+ 0,
285
+ 0,
286
+ ] # Main object (Blue)
287
+ colored_segmentation_mask[clean_binary_mask == 0] = [0, 0, 255] # Background (Red)
288
+ colored_segmentation_mask[blue_mask > 0] = [255, 0, 0]
289
+ return colored_segmentation_mask
290
+
291
+
292
+ def lrp_main(pixel_conversion):
293
+ i = 0
294
+ return_dict_count = 1
295
+ return_dict = {}
296
+ selected_indices = select_sample_images()
297
+ resized_shape = (128, 128)
298
+ cell_descriptors = [
299
+ ["Image Name", "Nucleus Area", "Cytoplasm Area", "Nucleus to Cytoplasm Ratio"]
300
+ ]
301
+
302
+ # MODEL SECTION STARTS FOR NEW MODEL
303
+ vgg16 = torchvision.models.vgg16(pretrained=True)
304
+ new_avgpool = nn.AdaptiveAvgPool2d(output_size=(4, 4))
305
+ vgg16.avgpool = new_avgpool
306
+ classifier_list = [
307
+ nn.Linear(8192, vgg16.classifier[0].out_features)
308
+ ] # vgg16.classifier[0].out_features = 4096
309
+ classifier_list += list(vgg16.classifier.children())[
310
+ 1:-1
311
+ ] # Remove the first and last layers
312
+ classifier_list += [
313
+ nn.Linear(vgg16.classifier[6].in_features, 2)
314
+ ] # vgg16.classifier[6].in_features = 4096
315
+ vgg16.classifier = nn.Sequential(
316
+ *classifier_list
317
+ ) # Replace the model classifier
318
+
319
+ PATH = "herlev_best_adam_vgg16_modified12_final.pth"
320
+ checkpoint = torch.load(PATH, map_location=torch.device("cpu"))
321
+ vgg16.load_state_dict(checkpoint)
322
+ # vgg16.to(torch.device('cuda'))
323
+ vgg16.eval()
324
+
325
+ layers = list(vgg16._modules["features"]) + toconv(
326
+ list(vgg16._modules["classifier"])
327
+ )
328
+ L = len(layers)
329
+ # MODEL SECTION ENDS
330
+
331
+ simple_lama = SimpleLama()
332
+
333
+ for imagefile in os.listdir(preprocessed_folder):
334
+ if (
335
+ "MACOSX".lower() in imagefile.lower()
336
+ or "." == imagefile[0]
337
+ or "_" == imagefile[0]
338
+ ):
339
+ print(imagefile)
340
+ continue
341
+ image_path = (
342
+ preprocessed_folder + os.path.splitext(imagefile)[0].lower() + ".png"
343
+ )
344
+ intermediate_path = (
345
+ intermediate_folder
346
+ + os.path.splitext(imagefile)[0].lower()
347
+ + "_heatmap.png"
348
+ )
349
+ save_path = (
350
+ segmentation_folder + os.path.splitext(imagefile)[0].lower() + "_mask.png"
351
+ )
352
+ table_path = (
353
+ tables_folder + os.path.splitext(imagefile)[0].lower() + "_table.png"
354
+ )
355
+
356
+ # print(i, imagefile)
357
+ image = cv2.imread(image_path)
358
+ original_shape = image.shape
359
+
360
+ image = cv2.resize(image, (128, 128))
361
+
362
+ layers_copy = copy.deepcopy(layers)
363
+ R = get_LRP_heatmap(image, L, layers_copy, imgclasses, intermediate_path)
364
+
365
+ rel_grouping = get_nucleus_mask_for_graphcut(R)
366
+
367
+ nucleus_segment, clean_binary_mask = segment_nucleus(image, rel_grouping)
368
+
369
+ expanded_mask, nucleus_removed_image = remove_nucleus(image, clean_binary_mask, simple_lama)
370
+
371
+ colored_segmentation_mask = get_final_mask(
372
+ nucleus_removed_image, clean_binary_mask, expanded_mask
373
+ )
374
+
375
+ cv2.imwrite(save_path, colored_segmentation_mask)
376
+
377
+ nucleus_area, cytoplasm_area, ratio = calculate_cell_descriptors(
378
+ original_shape, resized_shape, pixel_conversion, colored_segmentation_mask
379
+ )
380
+ cell_descriptors.append(
381
+ [
382
+ os.path.splitext(imagefile)[0].lower(),
383
+ nucleus_area,
384
+ cytoplasm_area,
385
+ ratio,
386
+ ]
387
+ )
388
+
389
+ create_cell_descriptors_table(table_path, nucleus_area, cytoplasm_area, ratio)
390
+
391
+ if i in selected_indices:
392
+ return_dict[f"image{return_dict_count}"] = str(
393
+ base64.b64encode(open(image_path, "rb").read()).decode("utf-8")
394
+ )
395
+ return_dict[f"inter{return_dict_count}"] = str(
396
+ base64.b64encode(open(intermediate_path, "rb").read()).decode("utf-8")
397
+ )
398
+ return_dict[f"mask{return_dict_count}"] = str(
399
+ base64.b64encode(open(save_path, "rb").read()).decode("utf-8")
400
+ )
401
+ return_dict[f"table{return_dict_count}"] = str(
402
+ base64.b64encode(open(table_path, "rb").read()).decode("utf-8")
403
+ )
404
+ return_dict_count += 1
405
+
406
+ i += 1
407
+
408
+ # Visualization
409
+ # for im in [image, gt2, rel_grouping, nucleus_segment, clean_binary_mask*255, nucleus_removed_image, colored_segmentation_mask]:
410
+ # cv2_imshow(im)
411
+
412
+ # write cell_descriptors list to csv file
413
+ with open(cell_descriptors_path, "w", newline="") as csv_file:
414
+ writer = csv.writer(csv_file)
415
+ writer.writerows(cell_descriptors)
416
+
417
+ return return_dict