hp733 commited on
Commit
87cdfd2
·
verified ·
1 Parent(s): c6fc64d

Update gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +11 -16
gradcam_utils.py CHANGED
@@ -151,28 +151,23 @@ from tf_explain.core.grad_cam import GradCAM
151
  import numpy as np
152
  from PIL import Image
153
 
154
- def generate_heatmap_tf_explain(image_pil, model, class_index):
155
- """
156
- Generates a Grad-CAM heatmap using tf-explain and overlays it on the original image.
157
-
158
- Parameters:
159
- image_pil (PIL.Image): Input chest X-ray image.
160
- model (tf.keras.Model): CNN model for explanation (e.g. VGG19).
161
- class_index (int): Index of the predicted class (0 or 1).
162
-
163
- Returns:
164
- heatmap_image (PIL.Image): Heatmap image overlaid on original image.
165
- """
166
- # Resize and preprocess image
167
  img_array = np.array(image_pil.resize((224, 224))) / 255.0
168
  img_array = np.expand_dims(img_array, axis=0)
169
 
170
- # Generate Grad-CAM explanation
 
 
 
171
  explainer = GradCAM()
172
  explanation = explainer.explain(
173
  validation_data=(img_array, None),
174
- model=model,
175
- class_index=class_index
 
176
  )
177
 
178
  return Image.fromarray(explanation)
 
151
  import numpy as np
152
  from PIL import Image
153
 
154
+ def generate_heatmap_tf_explain(image_pil, model, class_index, layer_name="block5_conv4"):
155
+ from tf_explain.core.grad_cam import GradCAM
156
+
157
+ # Preprocess image
 
 
 
 
 
 
 
 
 
158
  img_array = np.array(image_pil.resize((224, 224))) / 255.0
159
  img_array = np.expand_dims(img_array, axis=0)
160
 
161
+ # Reconstruct model to include target layer
162
+ from tensorflow.keras.models import Model
163
+ model_for_explanation = Model(inputs=model.input, outputs=model.output)
164
+
165
  explainer = GradCAM()
166
  explanation = explainer.explain(
167
  validation_data=(img_array, None),
168
+ model=model_for_explanation,
169
+ class_index=class_index,
170
+ layer_name=layer_name
171
  )
172
 
173
  return Image.fromarray(explanation)