piperod91 committed on
Commit
81d08ae
·
1 Parent(s): 5219e48

Explanations: align families with classifier top-5; accept target_labels from UI

Browse files
Files changed (2) hide show
  1. app.py +48 -6
  2. explanations.py +50 -9
app.py CHANGED
@@ -257,8 +257,34 @@ def generate_diagram_closest(input_image,model_name,top_k):
257
  diagram_path = get_diagram(embedding,top_k,model_name)
258
  return diagram_path
259
 
260
- def explain_image(input_image,model_name,explain_method,nb_samples,heatmap_alpha=0.22):
261
- model,n_classes= get_model(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  from explanations import explain
263
  if model_name in ('Fossils BEiT', 'Fossils 142'):
264
  size = 384
@@ -266,7 +292,19 @@ def explain_image(input_image,model_name,explain_method,nb_samples,heatmap_alpha
266
  size = 600
267
  #saliency, integrated, smoothgrad,
268
  h, w = input_image.shape[:2]
269
- classes,exp_list = explain(model,input_image, h, w, explain_method,nb_samples,size = size, n_classes=n_classes, heatmap_alpha=heatmap_alpha)
 
 
 
 
 
 
 
 
 
 
 
 
270
  #original = saliency + integrated + smoothgrad
271
  print('done')
272
 
@@ -959,8 +997,8 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
959
  # classify_button = gr.Button("Classify Image")
960
 
961
 
962
- def update_exp_outputs(input_image,model_name,explain_method,nb_samples,heatmap_alpha):
963
- labels, images = explain_image(input_image,model_name,explain_method,nb_samples,heatmap_alpha)
964
  #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
965
  #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
966
  image_caption=[]
@@ -968,7 +1006,11 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
968
  image_caption.append((images[i],"Predicted Plant Family "+str(i)+": "+labels[i]))
969
  return image_caption
970
 
971
- generate_explanations.click(fn=update_exp_outputs, inputs=[original_image,model_name,explain_method,sampling_size,heatmap_alpha], outputs=[exp_gallery])
 
 
 
 
972
 
973
  #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
974
  def _closest_table_html(labels, images, filenames):
 
257
  diagram_path = get_diagram(embedding,top_k,model_name)
258
  return diagram_path
259
 
260
+ def _top_k_classes_from_label_output(label_output, k=5):
261
+ """
262
+ Extract the top-k class names from a Gradio Label output.
263
+ Supports dict[label -> confidence] or list[(label, confidence), ...].
264
+ """
265
+ if not label_output:
266
+ return []
267
+ # Dict: sort by confidence, descending
268
+ if isinstance(label_output, dict):
269
+ items = sorted(label_output.items(), key=lambda kv: kv[1], reverse=True)
270
+ return [name for name, _ in items[:k]]
271
+ # List: assume list of (label, confidence) or labels
272
+ if isinstance(label_output, list):
273
+ names = []
274
+ for item in label_output[:k]:
275
+ if isinstance(item, (list, tuple)) and len(item) >= 1:
276
+ names.append(item[0])
277
+ else:
278
+ names.append(item)
279
+ return names
280
+ # Single label string
281
+ if isinstance(label_output, str):
282
+ return [label_output]
283
+ return []
284
+
285
+
286
+ def explain_image(input_image, model_name, explain_method, nb_samples, class_predicted, heatmap_alpha=0.22):
287
+ model, n_classes = get_model(model_name)
288
  from explanations import explain
289
  if model_name in ('Fossils BEiT', 'Fossils 142'):
290
  size = 384
 
292
  size = 600
293
  #saliency, integrated, smoothgrad,
294
  h, w = input_image.shape[:2]
295
+ target_labels = _top_k_classes_from_label_output(class_predicted, k=5)
296
+ classes, exp_list = explain(
297
+ model,
298
+ input_image,
299
+ h,
300
+ w,
301
+ explain_method,
302
+ nb_samples,
303
+ size=size,
304
+ n_classes=n_classes,
305
+ heatmap_alpha=heatmap_alpha,
306
+ target_labels=target_labels or None,
307
+ )
308
  #original = saliency + integrated + smoothgrad
309
  print('done')
310
 
 
997
  # classify_button = gr.Button("Classify Image")
998
 
999
 
1000
+ def update_exp_outputs(input_image, model_name, explain_method, nb_samples, class_predicted, heatmap_alpha):
1001
+ labels, images = explain_image(input_image, model_name, explain_method, nb_samples, class_predicted, heatmap_alpha)
1002
  #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
1003
  #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
1004
  image_caption=[]
 
1006
  image_caption.append((images[i],"Predicted Plant Family "+str(i)+": "+labels[i]))
1007
  return image_caption
1008
 
1009
+ generate_explanations.click(
1010
+ fn=update_exp_outputs,
1011
+ inputs=[original_image, model_name, explain_method, sampling_size, class_predicted, heatmap_alpha],
1012
+ outputs=[exp_gallery],
1013
+ )
1014
 
1015
  #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
1016
  def _closest_table_html(labels, images, filenames):
explanations.py CHANGED
@@ -1,8 +1,22 @@
1
  import xplique
2
  import tensorflow as tf
3
- from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
4
- SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
5
- GradCAMPP, Lime, KernelShap,SobolAttributionMethod,HsicAttributionMethod)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from xplique.attributions.global_sensitivity_analysis import LatinHypercube
7
  import numpy as np
8
  import matplotlib.pyplot as plt
@@ -10,6 +24,7 @@ from inference_resnet import inference_resnet_finer, preprocess
10
  from labels import lookup_140
11
  import cv2
12
  BATCH_SIZE = 1
 
13
 
14
 
15
  def letterbox_preprocess(img, size):
@@ -94,7 +109,18 @@ def show(img, original_size, output_size,p=False, **kwargs):
94
 
95
 
96
 
97
- def explain(model, input_image,h,w,explain_method,nb_samples,size=600, n_classes=171, heatmap_alpha=0.22) :
 
 
 
 
 
 
 
 
 
 
 
98
  """
99
  Generate explanations for a given model and dataset.
100
  :param model: The model to explain.
@@ -142,12 +168,27 @@ def explain(model, input_image,h,w,explain_method,nb_samples,size=600, n_classes
142
  content_mask = np.zeros((size, size), dtype=np.float32)
143
  content_mask[content_top : content_top + content_h, content_left : content_left + content_w] = 1.0
144
 
145
- predictions = class_model.predict(np.array([X]))
146
- #Y = np.argmax(predictions)
147
- top_5_indices = np.argsort(predictions[0])[-5:][::-1]
148
  classes = []
149
- for index in top_5_indices:
150
- classes.append(lookup_140[index])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  #print(top_5_indices)
152
  X = np.expand_dims(X, 0)
153
  explanations = []
 
1
  import xplique
2
  import tensorflow as tf
3
+ from xplique.attributions import (
4
+ Saliency,
5
+ GradientInput,
6
+ IntegratedGradients,
7
+ SmoothGrad,
8
+ VarGrad,
9
+ SquareGrad,
10
+ GradCAM,
11
+ Occlusion,
12
+ Rise,
13
+ GuidedBackprop,
14
+ GradCAMPP,
15
+ Lime,
16
+ KernelShap,
17
+ SobolAttributionMethod,
18
+ HsicAttributionMethod,
19
+ )
20
  from xplique.attributions.global_sensitivity_analysis import LatinHypercube
21
  import numpy as np
22
  import matplotlib.pyplot as plt
 
24
  from labels import lookup_140
25
  import cv2
26
  BATCH_SIZE = 1
27
+ _FAMILY_TO_INDEX = {v: k for k, v in lookup_140.items()}
28
 
29
 
30
  def letterbox_preprocess(img, size):
 
109
 
110
 
111
 
112
+ def explain(
113
+ model,
114
+ input_image,
115
+ h,
116
+ w,
117
+ explain_method,
118
+ nb_samples,
119
+ size=600,
120
+ n_classes=171,
121
+ heatmap_alpha=0.22,
122
+ target_labels=None,
123
+ ):
124
  """
125
  Generate explanations for a given model and dataset.
126
  :param model: The model to explain.
 
168
  content_mask = np.zeros((size, size), dtype=np.float32)
169
  content_mask[content_top : content_top + content_h, content_left : content_left + content_w] = 1.0
170
 
171
+ # Determine which classes to explain:
172
+ # - If target_labels are provided (from classifier output), use those.
173
+ # - Otherwise, fall back to top-5 classes from this forward pass.
174
  classes = []
175
+ if target_labels:
176
+ indices = []
177
+ for name in target_labels:
178
+ idx = _FAMILY_TO_INDEX.get(name)
179
+ if idx is not None:
180
+ indices.append(idx)
181
+ if indices:
182
+ top_5_indices = np.array(indices, dtype=int)
183
+ classes = [lookup_140[i] for i in top_5_indices]
184
+ else:
185
+ predictions = class_model.predict(np.array([X]))
186
+ top_5_indices = np.argsort(predictions[0])[-5:][::-1]
187
+ classes = [lookup_140[i] for i in top_5_indices]
188
+ else:
189
+ predictions = class_model.predict(np.array([X]))
190
+ top_5_indices = np.argsort(predictions[0])[-5:][::-1]
191
+ classes = [lookup_140[i] for i in top_5_indices]
192
  #print(top_5_indices)
193
  X = np.expand_dims(X, 0)
194
  explanations = []