hp733 commited on
Commit
695fb8b
·
verified ·
1 Parent(s): cd8c3af

Update gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +13 -13
gradcam_utils.py CHANGED
@@ -153,26 +153,26 @@ import tensorflow as tf
153
  from tensorflow.keras.preprocessing.image import img_to_array
154
  from PIL import Image
155
 
156
- def generate_heatmap_tf_explain(image_pil, model, class_index):
157
- # Convert to numpy and preprocess
158
- img = image_pil.resize((224, 224))
159
- img_array = img_to_array(img) / 255.0
 
160
  img_array = np.expand_dims(img_array, axis=0)
161
 
162
- # Prepare data
163
- data = ([img_array], None)
 
164
 
165
- # Run tf-explain GradCAM
166
  explainer = GradCAM()
167
  explanation = explainer.explain(
168
- validation_data=data,
169
- model=model,
170
  class_index=class_index,
171
- layer_name=None # will auto-detect final conv layer
172
  )
173
 
174
- # Convert explanation (numpy) to PIL
175
- heatmap_pil = Image.fromarray(explanation).resize((224, 224))
176
- return heatmap_pil
177
 
178
 
 
153
  from tensorflow.keras.preprocessing.image import img_to_array
154
  from PIL import Image
155
 
156
+ def generate_heatmap_tf_explain(image_pil, model, class_index, layer_name="block5_conv4"):
157
+ from tf_explain.core.grad_cam import GradCAM
158
+
159
+ # Preprocess image
160
+ img_array = np.array(image_pil.resize((224, 224))) / 255.0
161
  img_array = np.expand_dims(img_array, axis=0)
162
 
163
+ # Reconstruct model to include target layer
164
+ from tensorflow.keras.models import Model
165
+ model_for_explanation = Model(inputs=model.input, outputs=model.output)
166
 
 
167
  explainer = GradCAM()
168
  explanation = explainer.explain(
169
+ validation_data=(img_array, None),
170
+ model=model_for_explanation,
171
  class_index=class_index,
172
+ layer_name=layer_name
173
  )
174
 
175
+ return Image.fromarray(explanation)
176
+
 
177
 
178