hp733 commited on
Commit
8b1b4c7
·
verified ·
1 Parent(s): 3e47932

Update gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +99 -64
gradcam_utils.py CHANGED
@@ -77,68 +77,103 @@
77
 
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  import numpy as np
81
- import cv2
82
- import tensorflow as tf
83
- from tensorflow.keras.preprocessing.image import img_to_array, load_img
84
- from matplotlib.colors import LinearSegmentedColormap
85
-
86
- def preprocess_image(img_path, target_size):
87
- img = load_img(img_path, target_size=target_size)
88
- img = img_to_array(img)
89
- img = np.expand_dims(img, axis=0)
90
- img = img / 255.0
91
- return img
92
-
93
- def make_gradcam_heatmap(model, img_tensor):
94
- grad_model = tf.keras.models.Model([model.input], [model.output])
95
-
96
- with tf.GradientTape() as tape:
97
- conv_outputs = model(img_tensor)
98
- loss = conv_outputs[:, 1] # class index 1 = pneumonia
99
-
100
- grads = tape.gradient(loss, conv_outputs)
101
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
102
-
103
- conv_outputs = conv_outputs[0]
104
- heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
105
- heatmap = tf.squeeze(heatmap)
106
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
107
- heatmap = tf.where(tf.math.is_nan(heatmap), tf.zeros_like(heatmap), heatmap)
108
- return heatmap.numpy()
109
-
110
- def create_custom_colormap():
111
- colors = ['blue', 'green', 'yellow', 'red']
112
- cmap = LinearSegmentedColormap.from_list('custom', colors, N=256)
113
- return cmap
114
-
115
- def apply_custom_colormap(heatmap, cmap):
116
- colored_heatmap = cmap(heatmap)
117
- return np.uint8(colored_heatmap * 255)
118
-
119
- def enhance_heatmap(heatmap, gamma=0.7, percentile=99):
120
- heatmap = np.power(heatmap, gamma)
121
- heatmap = heatmap / np.percentile(heatmap, percentile)
122
- return np.clip(heatmap, 0, 1)
123
-
124
- def generate_and_merge_heatmaps(img_path, vgg_model, efficientnet_model, densenet_model, img_size=(224, 224)):
125
- img_tensor = preprocess_image(img_path, img_size)
126
-
127
- vgg_heatmap = make_gradcam_heatmap(vgg_model, img_tensor)
128
- efficientnet_heatmap = make_gradcam_heatmap(efficientnet_model, img_tensor)
129
- densenet_heatmap = make_gradcam_heatmap(densenet_model, img_tensor)
130
-
131
- vgg_heatmap = cv2.resize(vgg_heatmap, img_size)
132
- efficientnet_heatmap = cv2.resize(efficientnet_heatmap, img_size)
133
- densenet_heatmap = cv2.resize(densenet_heatmap, img_size)
134
-
135
- merged = (vgg_heatmap + efficientnet_heatmap + densenet_heatmap) / 3.0
136
- enhanced = enhance_heatmap(merged)
137
- colored = apply_custom_colormap(enhanced, create_custom_colormap())
138
-
139
- original = cv2.imread(img_path)
140
- original = cv2.resize(original, img_size)
141
- original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
142
-
143
- superimposed_img = cv2.addWeighted(original, 0.6, colored[:, :, :3], 0.4, 0)
144
- return superimposed_img
 
77
 
78
 
79
 
80
+ # import numpy as np
81
+ # import cv2
82
+ # import tensorflow as tf
83
+ # from tensorflow.keras.preprocessing.image import img_to_array, load_img
84
+ # from matplotlib.colors import LinearSegmentedColormap
85
+
86
+ # def preprocess_image(img_path, target_size):
87
+ # img = load_img(img_path, target_size=target_size)
88
+ # img = img_to_array(img)
89
+ # img = np.expand_dims(img, axis=0)
90
+ # img = img / 255.0
91
+ # return img
92
+
93
+ # def make_gradcam_heatmap(model, img_tensor):
94
+ # grad_model = tf.keras.models.Model([model.input], [model.output])
95
+
96
+ # with tf.GradientTape() as tape:
97
+ # conv_outputs = model(img_tensor)
98
+ # loss = conv_outputs[:, 1] # class index 1 = pneumonia
99
+
100
+ # grads = tape.gradient(loss, conv_outputs)
101
+ # pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
102
+
103
+ # conv_outputs = conv_outputs[0]
104
+ # heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
105
+ # heatmap = tf.squeeze(heatmap)
106
+ # heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
107
+ # heatmap = tf.where(tf.math.is_nan(heatmap), tf.zeros_like(heatmap), heatmap)
108
+ # return heatmap.numpy()
109
+
110
+ # def create_custom_colormap():
111
+ # colors = ['blue', 'green', 'yellow', 'red']
112
+ # cmap = LinearSegmentedColormap.from_list('custom', colors, N=256)
113
+ # return cmap
114
+
115
+ # def apply_custom_colormap(heatmap, cmap):
116
+ # colored_heatmap = cmap(heatmap)
117
+ # return np.uint8(colored_heatmap * 255)
118
+
119
+ # def enhance_heatmap(heatmap, gamma=0.7, percentile=99):
120
+ # heatmap = np.power(heatmap, gamma)
121
+ # heatmap = heatmap / np.percentile(heatmap, percentile)
122
+ # return np.clip(heatmap, 0, 1)
123
+
124
+ # def generate_and_merge_heatmaps(img_path, vgg_model, efficientnet_model, densenet_model, img_size=(224, 224)):
125
+ # img_tensor = preprocess_image(img_path, img_size)
126
+
127
+ # vgg_heatmap = make_gradcam_heatmap(vgg_model, img_tensor)
128
+ # efficientnet_heatmap = make_gradcam_heatmap(efficientnet_model, img_tensor)
129
+ # densenet_heatmap = make_gradcam_heatmap(densenet_model, img_tensor)
130
+
131
+ # vgg_heatmap = cv2.resize(vgg_heatmap, img_size)
132
+ # efficientnet_heatmap = cv2.resize(efficientnet_heatmap, img_size)
133
+ # densenet_heatmap = cv2.resize(densenet_heatmap, img_size)
134
+
135
+ # merged = (vgg_heatmap + efficientnet_heatmap + densenet_heatmap) / 3.0
136
+ # enhanced = enhance_heatmap(merged)
137
+ # colored = apply_custom_colormap(enhanced, create_custom_colormap())
138
+
139
+ # original = cv2.imread(img_path)
140
+ # original = cv2.resize(original, img_size)
141
+ # original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
142
+
143
+ # superimposed_img = cv2.addWeighted(original, 0.6, colored[:, :, :3], 0.4, 0)
144
+ # return superimposed_img
145
+
146
+
147
+
148
+
149
+
150
+ from tf_explain.core.grad_cam import GradCAM
151
  import numpy as np
152
+ from PIL import Image
153
+
154
+ def generate_heatmap_tf_explain(image_pil, model, class_index):
155
+ """
156
+ Generates a Grad-CAM heatmap using tf-explain and overlays it on the original image.
157
+
158
+ Parameters:
159
+ image_pil (PIL.Image): Input chest X-ray image.
160
+ model (tf.keras.Model): CNN model for explanation (e.g. VGG19).
161
+ class_index (int): Index of the predicted class (0 or 1).
162
+
163
+ Returns:
164
+ heatmap_image (PIL.Image): Heatmap image overlaid on original image.
165
+ """
166
+ # Resize and preprocess image
167
+ img_array = np.array(image_pil.resize((224, 224))) / 255.0
168
+ img_array = np.expand_dims(img_array, axis=0)
169
+
170
+ # Generate Grad-CAM explanation
171
+ explainer = GradCAM()
172
+ explanation = explainer.explain(
173
+ validation_data=(img_array, None),
174
+ model=model,
175
+ class_index=class_index
176
+ )
177
+
178
+ return Image.fromarray(explanation)
179
+