hp733 commited on
Commit
cd8c3af
·
verified ·
1 Parent(s): 2f1c285

Update gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +17 -13
gradcam_utils.py CHANGED
@@ -147,28 +147,32 @@
147
 
148
 
149
 
150
- from tf_explain.core.grad_cam import GradCAM
151
  import numpy as np
 
 
 
152
  from PIL import Image
153
 
154
- def generate_heatmap_tf_explain(image_pil, model, class_index, layer_name="block5_conv4"):
155
- from tf_explain.core.grad_cam import GradCAM
156
-
157
- # Preprocess image
158
- img_array = np.array(image_pil.resize((224, 224))) / 255.0
159
  img_array = np.expand_dims(img_array, axis=0)
160
 
161
- # Reconstruct model to include target layer
162
- from tensorflow.keras.models import Model
163
- model_for_explanation = Model(inputs=model.input, outputs=model.output)
164
 
 
165
  explainer = GradCAM()
166
  explanation = explainer.explain(
167
- validation_data=(img_array, None),
168
- model=model_for_explanation,
169
  class_index=class_index,
170
- layer_name=layer_name
171
  )
172
 
173
- return Image.fromarray(explanation)
 
 
 
174
 
 
147
 
148
 
149
 
 
150
  import numpy as np
151
+ from tf_explain.core.grad_cam import GradCAM
152
+ import tensorflow as tf
153
+ from tensorflow.keras.preprocessing.image import img_to_array
154
  from PIL import Image
155
 
156
+ def generate_heatmap_tf_explain(image_pil, model, class_index):
157
+ # Convert to numpy and preprocess
158
+ img = image_pil.resize((224, 224))
159
+ img_array = img_to_array(img) / 255.0
 
160
  img_array = np.expand_dims(img_array, axis=0)
161
 
162
+ # Prepare data
163
+ data = ([img_array], None)
 
164
 
165
+ # Run tf-explain GradCAM
166
  explainer = GradCAM()
167
  explanation = explainer.explain(
168
+ validation_data=data,
169
+ model=model,
170
  class_index=class_index,
171
+ layer_name=None # will auto-detect final conv layer
172
  )
173
 
174
+ # Convert explanation (numpy) to PIL
175
+ heatmap_pil = Image.fromarray(explanation).resize((224, 224))
176
+ return heatmap_pil
177
+
178