MarcoParola commited on
Commit
201ab5d
·
1 Parent(s): 97c0fe2

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /env
2
+ __pycache__/
3
+
4
+ /logs
5
+ /outputs
6
+ /.hydra
7
+ /checkpoints
8
+ /wandb
9
+ /models
10
+ /share
11
+ /bin
12
+ /lib
13
+ /lib64
14
+ /include
15
+ pyvenv.cfg
16
+ requirements.txt
17
+
18
+ *.log
19
+ *.pth
20
+ *.png
21
+
22
+
23
+ /lightning_logs
24
+ __pycache__/
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yaml
3
+ import random
4
+ import os
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+ from huggingface_hub import CommitScheduler, HfApi
9
+
10
+ from src.utils import load_words, load_image_and_saliency, load_example_images, load_csv_concepts
11
+ from src.style import css
12
+ from src.user import UserID
13
+
14
+ from datetime import datetime
15
+ from pathlib import Path
16
+ from uuid import uuid4
17
+ import json
18
+ from huggingface_hub import CommitScheduler
19
+
20
+ def main():
21
+ config = yaml.safe_load(open("config/config.yaml"))
22
+ words = ['grad-cam', 'lime', 'sidu', 'rise']
23
+ options = ['-', '1', '2', '3', '4']
24
+ class_names = config['dataset'][config['dataset']['name']]['class_names']
25
+ data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])
26
+
27
+ with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
28
+ # Main App Components
29
+ title = gr.Markdown("# Saliency evaluation - experiment 1")
30
+ user_state = gr.State(0)
31
+ answers = gr.State([])
32
+ start_time = gr.State(time.time())
33
+
34
+ concepts = load_csv_concepts(data_dir)
35
+
36
+ gr.Markdown("### Image examples")
37
+ with gr.Row():
38
+ count = user_state if isinstance(user_state, int) else user_state.value
39
+ images = load_example_images(count, data_dir)
40
+ img1 = gr.Image(images[0])
41
+ img2 = gr.Image(images[1])
42
+ img3 = gr.Image(images[2])
43
+ img4 = gr.Image(images[3])
44
+ img5 = gr.Image(images[4])
45
+ img6 = gr.Image(images[5])
46
+ img7 = gr.Image(images[6])
47
+ img8 = gr.Image(images[7])
48
+ img9 = gr.Image(images[8])
49
+ img10 = gr.Image(images[9])
50
+ img11 = gr.Image(images[10])
51
+ img12 = gr.Image(images[11])
52
+ img13 = gr.Image(images[12])
53
+ img14 = gr.Image(images[13])
54
+ img15 = gr.Image(images[14])
55
+ img16 = gr.Image(images[15])
56
+
57
+ count = user_state if isinstance(user_state, int) else user_state.value
58
+ row = concepts.iloc[count]
59
+ question = gr.Markdown(f"### Sort the following saliency maps according to which of them better explains the class {class_names[count]}.", visible=False)
60
+
61
+ with gr.Row():
62
+ target_img_label = gr.Markdown(f"Target image: **{class_names[user_state.value]}**")
63
+ gr.Markdown("Grad-cam")
64
+ gr.Markdown("Lime")
65
+ gr.Markdown("Sidu")
66
+ gr.Markdown("Rise")
67
+
68
+ with gr.Row():
69
+ count = user_state if isinstance(user_state, int) else user_state.value
70
+ images = load_image_and_saliency(count, data_dir)
71
+ target_img = gr.Image(images[0], elem_classes="main-image delay", visible=False)
72
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=False)
73
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=False)
74
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=False)
75
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=False)
76
+
77
+
78
+ with gr.Row():
79
+ dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
80
+ dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
81
+ dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
82
+ dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)
83
+
84
+ continue_button = gr.Button("Continue")
85
+ submit_button = gr.Button("Submit", visible=False)
86
+ finish_button = gr.Button("Finish", visible=False)
87
+
88
+ def update_images(user_state):
89
+ count = user_state if isinstance(user_state, int) else user_state.value
90
+ if count < config['dataset'][config['dataset']['name']]['n_classes']:
91
+ images = load_image_and_saliency(count, data_dir)
92
+
93
+ # image examples
94
+ images = load_example_images(count, data_dir)
95
+ img1 = gr.Image(images[0], visible=True)
96
+ img2 = gr.Image(images[1], visible=True)
97
+ img3 = gr.Image(images[2], visible=True)
98
+ img4 = gr.Image(images[3], visible=True)
99
+ img5 = gr.Image(images[4], visible=True)
100
+ img6 = gr.Image(images[5], visible=True)
101
+ img7 = gr.Image(images[6], visible=True)
102
+ img8 = gr.Image(images[7], visible=True)
103
+ img9 = gr.Image(images[8], visible=True)
104
+ img10 = gr.Image(images[9], visible=True)
105
+ img11 = gr.Image(images[10], visible=True)
106
+ img12 = gr.Image(images[11], visible=True)
107
+ img13 = gr.Image(images[12], visible=True)
108
+ img14 = gr.Image(images[13], visible=True)
109
+ img15 = gr.Image(images[14], visible=True)
110
+ img16 = gr.Image(images[15], visible=True)
111
+ return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
112
+ else:
113
+ return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
114
+
115
+ def update_saliencies(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
116
+ count = user_state if isinstance(user_state, int) else user_state.value
117
+ if count < config['dataset'][config['dataset']['name']]['n_classes']:
118
+ images = load_image_and_saliency(count, data_dir)
119
+ target_img = gr.Image(images[0], elem_classes="main-image", visible=True)
120
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=True)
121
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=True)
122
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=True)
123
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=True)
124
+ return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
125
+ else:
126
+ return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
127
+
128
+ def update_state(state):
129
+ count = state if isinstance(state, int) else state.value
130
+ return gr.State(count + 1)
131
+
132
+ def update_img_label(state):
133
+ count = state if isinstance(state, int) else state.value
134
+ return f" Target image: **{class_names[count]}**"
135
+
136
+ def update_buttons():
137
+ submit_button = gr.Button("Submit", visible=False)
138
+ continue_button = gr.Button("Continue", visible=True)
139
+ return continue_button, submit_button
140
+
141
+ def show_view(state):
142
+ count = state if isinstance(state, int) else state.value
143
+ max_images = config['dataset'][config['dataset']['name']]['n_classes']
144
+ finish_button = gr.Button("Finish", visible=(count == max_images-1))
145
+ submit_button = gr.Button("Submit", visible=(count != max_images-1))
146
+ continue_button = gr.Button("Continue", visible=False)
147
+ return continue_button, submit_button, finish_button
148
+
149
+
150
+ def hide_view():
151
+ target_img = gr.Image(images[0], elem_classes="main-image", visible=False)
152
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=False)
153
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=False)
154
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=False)
155
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=False)
156
+ question = gr.Markdown(f"### Sort the following saliency maps according to which of them better explains the class {class_names[count]}.", visible=False)
157
+ dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
158
+ dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
159
+ dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
160
+ dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)
161
+ return question, target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, dropdown1, dropdown2, dropdown3, dropdown4
162
+
163
+
164
+ def update_dropdowns():
165
+ dp1 = gr.Dropdown(choices=options, value=options[0], label="grad-cam", visible=True)
166
+ dp2 = gr.Dropdown(choices=options, value=options[0], label="lime", visible=True)
167
+ dp3 = gr.Dropdown(choices=options, value=options[0], label="sidu", visible=True)
168
+ dp4 = gr.Dropdown(choices=options, value=options[0], label="rise", visible=True)
169
+ return dp1, dp2, dp3, dp4
170
+
171
+ def update_questions(state):
172
+ concepts = load_csv_concepts(data_dir)
173
+ count = state if isinstance(state, int) else state.value
174
+ row = concepts.iloc[count]
175
+ return gr.Markdown(f"### Sort the following saliency maps according to which of them better explains the class {class_names[count]}.", visible=True)
176
+
177
+ def redirect():
178
+ pass
179
+
180
+ def save_results(answers):
181
+ api_token = os.getenv("HUGGINGFACE_TOKEN")
182
+ if not api_token:
183
+ raise ValueError("Hugging Face API token not found. Please set the HF_API_TOKEN environment variable.")
184
+
185
+ json_file_results = config['results']['exp1_dir'] # 'exp1'
186
+ JSON_DATASET_DIR = Path("json_dataset")
187
+ JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
188
+ JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"
189
+ scheduler = CommitScheduler(
190
+ repo_id=f"results_{config['dataset']['name']}_{config['results']['exp1_dir']}", # The repo id
191
+ repo_type="dataset",
192
+ folder_path=JSON_DATASET_DIR,
193
+ path_in_repo="data",
194
+ token=api_token # Pass the token here
195
+ )
196
+
197
+ duration = time.time() - start_time.value
198
+
199
+ info_to_push = {
200
+ "user_id": time.time(),
201
+ "answer": {i: answer for i, answer in enumerate(answers)},
202
+ "duration": duration
203
+ }
204
+
205
+ # Save the results into huggingface hub
206
+ with scheduler.lock:
207
+ with JSON_DATASET_PATH.open("a") as f:
208
+ json.dump({
209
+ "user_id": info_to_push["user_id"],
210
+ "answers": info_to_push["answer"],
211
+ "duration": info_to_push["duration"],
212
+ "datetime": datetime.now().isoformat()
213
+ }, f)
214
+ f.write("\n")
215
+ scheduler.push_to_hub()
216
+
217
+ def check_answer(dropdown1, dropdown2, dropdown3, dropdown4):
218
+ if '-' in [dropdown1, dropdown2, dropdown3, dropdown4]:
219
+ raise gr.Error('Please select a value for each saliency method')
220
+ # check if all values are different 1,2,3,4
221
+ if len(set([dropdown1, dropdown2, dropdown3, dropdown4])) < 4:
222
+ print(set([dropdown1, dropdown2, dropdown3, dropdown4]))
223
+ raise gr.Error('Please select different values for each saliency method')
224
+
225
+ def add_answer(dropdown1,dropdown2,dropdown3,dropdown4, answers):
226
+ rank = [dropdown1,dropdown2,dropdown3,dropdown4]
227
+ answers.append(rank)
228
+ return answers
229
+
230
+ submit_button.click(
231
+ check_answer,
232
+ inputs=[dropdown1, dropdown2, dropdown3, dropdown4]
233
+ ).success(
234
+ update_state,
235
+ inputs=user_state,
236
+ outputs=user_state
237
+ ).then(
238
+ add_answer,
239
+ inputs=[dropdown1, dropdown2, dropdown3, dropdown4, answers],
240
+ outputs=answers
241
+ ).then(
242
+ update_img_label,
243
+ inputs=user_state,
244
+ outputs=target_img_label
245
+ ).then(
246
+ update_images,
247
+ inputs=user_state,
248
+ outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16]
249
+ ).then(
250
+ update_buttons,
251
+ outputs={continue_button, submit_button}
252
+ ).then(
253
+ hide_view,
254
+ outputs={question, target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, dropdown1, dropdown2, dropdown3, dropdown4}
255
+ )
256
+
257
+ continue_button.click(
258
+ show_view,
259
+ inputs=user_state,
260
+ outputs={continue_button, submit_button, finish_button}
261
+ ).then(
262
+ update_img_label,
263
+ inputs=user_state,
264
+ outputs=target_img_label
265
+ ).then(
266
+ update_saliencies,
267
+ inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
268
+ outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise},
269
+ ).then(
270
+ update_questions,
271
+ inputs=user_state,
272
+ outputs=question
273
+ ).then(
274
+ update_dropdowns,
275
+ outputs={dropdown1, dropdown2, dropdown3, dropdown4}
276
+ )
277
+
278
+
279
+ finish_button.click(
280
+ add_answer, inputs=[dropdown1, dropdown2, dropdown3, dropdown4, answers],outputs=answers
281
+ ).then(
282
+ save_results, inputs=answers
283
+ ).then(
284
+ redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'")
285
+
286
+ demo.load()
287
+ demo.launch(root_path='/')
288
+
289
+ if __name__ == "__main__":
290
+ main()
config/config.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_dir: data
2
+ image_dir: images
3
+ saliency_dir: saliency
4
+ repo_id: "MarcoParola/saliency-evaluation"
5
+
6
+ gui:
7
+ max_img_examples: 16
8
+
9
+ experiments: exp1
10
+
11
+ results:
12
+ save_dir: results
13
+ exp1_dir: exp1
14
+ exp2_dir: exp2
15
+
16
+ dataset:
17
+ name: intel_image
18
+ path: data
19
+ intel_image:
20
+ n_classes: 6
21
+ class_names: ['BUILDING', 'FOREST', 'GLACIER', 'MOUNTAIN', 'SEA', 'STREET']
22
+ imagenette:
23
+ n_classes: 10
24
+ class_names: ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
25
+
26
+ saliency_methods:
27
+ - gradcam
28
+ - lime
29
+ - sidu
30
+ - rise
data/intel_image/concepts_by_class.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ class, concept1, concept2, concept3, concept4, concept5, concept6, concept7, concept8, concept9, concept10, concept11, concept12, concept13, concept14, concept15, concept16
2
+ buildings, Roof, Window, Facade, Wall, Boat, Tree, Sky, Car, Streetlights, Sidewalk, Beach, Vegetation, Water, Mountain Peak, Rock, Ice
3
+ forest, Vegetation, Tree, Water, Sidewalk, Facade, Sky, Beach, Wall, Rock, Window, Ice, Roof, Streetlights, Car, Mountain Peak, Boat
4
+ glacier, Ice, Rock, Mountain Peak, Water, Wall, Beach, Sky, Vegetation, Sidewalk, Facade, Roof, Tree, Window, Boat, Streetlights, Car
5
+ mountain, Mountain Peak, Rock, Vegetation, Sky, Tree, Ice, Water, Beach, Wall, Facade, Roof, Boat, Sidewalk, Window, Streetlights, Car
6
+ sea, Water, Boat, Beach, Sky, Rock, Sidewalk, Wall, Ice, Roof, Vegetation, Facade, Mountain Peak, Tree, Streetlights, Window, Car
7
+ street, Car, Streetlights, Sidewalk, Boat, Wall, Facade, Tree, Roof, Beach, Sky, Window, Vegetation, Water, Rock, Mountain Peak, Ice
json_dataset/train-a2d46b18-1281-4a68-a405-49720965a1c7.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"user_id": 1738943192.9120383, "answers": {"0": ["-", "-", "-", "-"], "1": ["-", "-", "-", "-"], "2": ["-", "-", "-", "-"], "3": ["-", "-", "-", "-"], "4": ["-", "-", "-", "-"], "5": ["-", "-", "-", "-"]}, "datetime": "2025-02-07T16:46:32.916261"}
src/style.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ css = """
2
+ #gallery {
3
+ height: 300px;
4
+ }
5
+
6
+ .main-image {
7
+ width: 200px;
8
+ height: 200px;
9
+ object-fit: cover;
10
+ }
11
+
12
+ .gallery-textlabel > * {
13
+ h2 {
14
+ font-weight: medium;
15
+ text-align: center;
16
+ margin-top: 1px;
17
+ padding: 0px;
18
+ font-size: 1em;
19
+ }
20
+ .svelte-i3tvor {
21
+ display:none;
22
+ visibility: hidden;
23
+ font-size: 0.02em;
24
+ }
25
+ }
26
+ """
src/user.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from threading import Lock
4
+
5
+ class UserID:
6
+ def __init__(self):
7
+ self.lock = Lock()
8
+ self.counter = 0
9
+ if os.path.exists('global_variable.csv'):
10
+ df = pd.read_csv('global_variable.csv')
11
+ self.counter = df['value'][0]
12
+
13
+ def increment(self):
14
+ with self.lock:
15
+ self.counter += 1
16
+ df = pd.DataFrame({'value': [self.counter]})
17
+ df.to_csv('global_variable.csv', index=False)
18
+ return self.counter
src/utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from huggingface_hub import HfApi, HfFolder
4
+ import yaml
5
+ import numpy as np
6
+ import time
7
+
8
+ config = yaml.safe_load(open("./config/config.yaml"))
9
+
10
+ def load_image_and_saliency(class_idx, data_dir):
11
+ path = os.path.join(data_dir, 'images', str(class_idx))
12
+ images = os.listdir(path)
13
+ # pick a random image
14
+ # set random seed usiing time
15
+ np.random.seed(int(time.time()))
16
+ id = np.random.randint(0, len(images))
17
+ image = os.path.join(path, images[id])
18
+ gradcam_image = os.path.join(data_dir, 'saliency', 'gradcam', images[id])
19
+ lime_image = os.path.join(data_dir, 'saliency', 'lime', images[id])
20
+ sidu_image = os.path.join(data_dir, 'saliency', 'sidu', images[id])
21
+ rise_image = os.path.join(data_dir, 'saliency', 'rise', images[id])
22
+ return image, gradcam_image, lime_image, sidu_image, rise_image
23
+
24
+ def load_example_images(class_idx, data_dir, max_images=16):
25
+ path = os.path.join(data_dir, 'images', str(class_idx))
26
+ images = os.listdir(path)
27
+ # set random seed usiing time
28
+ np.random.seed(int(time.time()))
29
+ ids = np.random.choice(len(images), max_images, replace=False)
30
+ images = [os.path.join(path, images[id]) for id in ids]
31
+ return images
32
+
33
+ # Function to load words based on global variable
34
+ def load_words(idx):
35
+ words = [f"word_{idx}_{i}" for i in range(20)]
36
+ return words
37
+
38
+
39
+ def load_csv_concepts(data_dir):
40
+ # Load data from csv
41
+ data = pd.read_csv(os.path.join(data_dir, 'concepts_by_class.csv'))
42
+ return data
43
+