MarcoParola commited on
Commit
01294f2
·
1 Parent(s): 201ab5d

save image id

Browse files
Files changed (2) hide show
  1. app.py +14 -11
  2. src/utils.py +1 -1
app.py CHANGED
@@ -29,6 +29,7 @@ def main():
29
  title = gr.Markdown("# Saliency evaluation - experiment 1")
30
  user_state = gr.State(0)
31
  answers = gr.State([])
 
32
  start_time = gr.State(time.time())
33
 
34
  concepts = load_csv_concepts(data_dir)
@@ -68,11 +69,12 @@ def main():
68
  with gr.Row():
69
  count = user_state if isinstance(user_state, int) else user_state.value
70
  images = load_image_and_saliency(count, data_dir)
71
- target_img = gr.Image(images[0], elem_classes="main-image delay", visible=False)
72
- saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=False)
73
- saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=False)
74
- saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=False)
75
- saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=False)
 
76
 
77
 
78
  with gr.Row():
@@ -88,7 +90,7 @@ def main():
88
  def update_images(user_state):
89
  count = user_state if isinstance(user_state, int) else user_state.value
90
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
91
- images = load_image_and_saliency(count, data_dir)
92
 
93
  # image examples
94
  images = load_example_images(count, data_dir)
@@ -112,18 +114,19 @@ def main():
112
  else:
113
  return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
114
 
115
- def update_saliencies(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
116
  count = user_state if isinstance(user_state, int) else user_state.value
117
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
118
  images = load_image_and_saliency(count, data_dir)
 
119
  target_img = gr.Image(images[0], elem_classes="main-image", visible=True)
120
  saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=True)
121
  saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=True)
122
  saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=True)
123
  saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=True)
124
- return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
125
  else:
126
- return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
127
 
128
  def update_state(state):
129
  count = state if isinstance(state, int) else state.value
@@ -264,8 +267,8 @@ def main():
264
  outputs=target_img_label
265
  ).then(
266
  update_saliencies,
267
- inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
268
- outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise},
269
  ).then(
270
  update_questions,
271
  inputs=user_state,
 
29
  title = gr.Markdown("# Saliency evaluation - experiment 1")
30
  user_state = gr.State(0)
31
  answers = gr.State([])
32
+ img_ids = gr.State([])
33
  start_time = gr.State(time.time())
34
 
35
  concepts = load_csv_concepts(data_dir)
 
69
  with gr.Row():
70
  count = user_state if isinstance(user_state, int) else user_state.value
71
  images = load_image_and_saliency(count, data_dir)
72
+ img_ids = gr.State([images[0]])
73
+ target_img = gr.Image(images[1], elem_classes="main-image delay", visible=False)
74
+ saliency_gradcam = gr.Image(images[2], elem_classes="main-image", visible=False)
75
+ saliency_lime = gr.Image(images[3], elem_classes="main-image", visible=False)
76
+ saliency_sidu = gr.Image(images[5], elem_classes="main-image", visible=False)
77
+ saliency_rise = gr.Image(images[4], elem_classes="main-image", visible=False)
78
 
79
 
80
  with gr.Row():
 
90
  def update_images(user_state):
91
  count = user_state if isinstance(user_state, int) else user_state.value
92
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
93
+ #images = load_image_and_saliency(count, data_dir)
94
 
95
  # image examples
96
  images = load_example_images(count, data_dir)
 
114
  else:
115
  return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
116
 
117
+ def update_saliencies(dropdown1, dropdown2, dropdown3, dropdown4, user_state, img_ids):
118
  count = user_state if isinstance(user_state, int) else user_state.value
119
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
120
  images = load_image_and_saliency(count, data_dir)
121
+ img_ids.append(images[0])
122
  target_img = gr.Image(images[0], elem_classes="main-image", visible=True)
123
  saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=True)
124
  saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=True)
125
  saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=True)
126
  saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=True)
127
+ return img_ids, target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
128
  else:
129
+ return img_ids, target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
130
 
131
  def update_state(state):
132
  count = state if isinstance(state, int) else state.value
 
267
  outputs=target_img_label
268
  ).then(
269
  update_saliencies,
270
+ inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state, img_ids],
271
+ outputs={img_ids, target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise},
272
  ).then(
273
  update_questions,
274
  inputs=user_state,
src/utils.py CHANGED
@@ -19,7 +19,7 @@ def load_image_and_saliency(class_idx, data_dir):
19
  lime_image = os.path.join(data_dir, 'saliency', 'lime', images[id])
20
  sidu_image = os.path.join(data_dir, 'saliency', 'sidu', images[id])
21
  rise_image = os.path.join(data_dir, 'saliency', 'rise', images[id])
22
- return image, gradcam_image, lime_image, sidu_image, rise_image
23
 
24
  def load_example_images(class_idx, data_dir, max_images=16):
25
  path = os.path.join(data_dir, 'images', str(class_idx))
 
19
  lime_image = os.path.join(data_dir, 'saliency', 'lime', images[id])
20
  sidu_image = os.path.join(data_dir, 'saliency', 'sidu', images[id])
21
  rise_image = os.path.join(data_dir, 'saliency', 'rise', images[id])
22
+ return id, image, gradcam_image, lime_image, sidu_image, rise_image
23
 
24
  def load_example_images(class_idx, data_dir, max_images=16):
25
  path = os.path.join(data_dir, 'images', str(class_idx))