|
|
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.style import css
from src.user import UserID
from src.utils import (
    get_random_image_id,
    load_csv_concepts,
    load_example_images,
    load_image_and_saliency,
    load_words,
)
|
|
|
|
|
def main():
    """Build and launch the Gradio app for saliency-evaluation experiment 1.

    The app reads ``config/config.yaml``, shows 16 example images of the
    current class, and then asks the participant to rank four saliency maps
    (grad-cam, lime, sidu, rise) for a target image of that class.  The flow
    per class is: *Continue* reveals the target image, the saliency maps and
    the ranking dropdowns; *Submit* validates and stores the ranking and moves
    on to the next class.  On the last class *Finish* replaces *Submit*: it
    validates and stores the last ranking, writes all answers to a JSON file
    that is pushed to a Hugging Face dataset repo, and redirects the browser
    to the end page.

    Raises:
        ValueError: (from the Finish handler) if the ``HUGGINGFACE_TOKEN``
            environment variable is not set.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open("config/config.yaml") as config_file:
        config = yaml.safe_load(config_file)

    options = ['-', '1', '2', '3', '4']
    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)
    n_examples = 16  # number of example-image slots shown per class
    scheduler_cache = {}  # holds the single CommitScheduler once it is created

    def unwrap(state):
        """Return the plain value of *state* (a raw value or a ``gr.State``)."""
        return state.value if isinstance(state, gr.State) else state

    def question_text(count):
        """Return the ranking prompt for the class at index *count*."""
        return (
            "### Sort the following saliency maps according to which of them "
            f"better explains the class {class_names[count]}."
        )

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)   # index of the class currently evaluated
        answers = gr.State([])     # one [grad-cam, lime, sidu, rise] rank per class
        # Initial value only; it is refreshed per session on page load (below).
        start_time = gr.State(time.time())

        # NOTE(review): the loaded concepts are not used anywhere in this
        # function; the call is kept only because its side effects cannot be
        # seen from here -- drop it if it is a pure read.
        load_csv_concepts(data_dir)

        initial_class = unwrap(user_state)

        gr.Markdown("### Image examples")
        with gr.Row():
            examples = load_example_images(initial_class, data_dir)
            example_imgs = [gr.Image(examples[i]) for i in range(n_examples)]

        question = gr.Markdown(question_text(initial_class), visible=False)

        with gr.Row():
            target_img_label = gr.Markdown(
                f"Target image: **{class_names[initial_class]}**"
            )
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            # NOTE(review): this first image id is drawn once at build time, so
            # every session starts from the same image -- confirm this is wanted.
            first_img_id = get_random_image_id(initial_class, data_dir)
            img_ids = gr.State([first_img_id])
            first = load_image_and_saliency(initial_class, data_dir, first_img_id)
            target_img = gr.Image(first[0], elem_classes="main-image delay", visible=False)
            saliency_gradcam = gr.Image(first[1], elem_classes="main-image", visible=False)
            saliency_lime = gr.Image(first[2], elem_classes="main-image", visible=False)
            # sidu is taken from index 4 and rise from index 3, as in every
            # construction site of the original code.
            saliency_sidu = gr.Image(first[4], elem_classes="main-image", visible=False)
            saliency_rise = gr.Image(first[3], elem_classes="main-image", visible=False)

        with gr.Row():
            dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
            dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
            dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
            dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        dropdowns = [dropdown1, dropdown2, dropdown3, dropdown4]
        # Explicit, ordered output lists: handlers return values positionally,
        # so the order here must match each handler's return order.
        saliency_outputs = [
            target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise,
        ]

        def reset_start_time():
            """Return the current time; used to start the per-session timer."""
            return time.time()

        def update_images(user_state):
            """Return the 16 example images for the current class."""
            count = unwrap(user_state)
            if count >= n_classes:
                # Past the last class: leave the example images untouched.
                return [gr.update() for _ in range(n_examples)]
            images = load_example_images(count, data_dir)
            return [gr.Image(images[i], visible=True) for i in range(n_examples)]

        def update_saliencies(user_state, img_ids):
            """Return the target image and its four saliency maps, made visible.

            Values are returned in the order of ``saliency_outputs``:
            target, grad-cam, lime, sidu, rise.
            """
            count = unwrap(user_state)
            if count >= n_classes:
                return [gr.update() for _ in saliency_outputs]
            img_id = unwrap(img_ids[-1])
            print(f"Updating saliency maps for class index: {count} and image ID: {img_id}")
            images = load_image_and_saliency(count, data_dir, img_id)
            # NOTE(review): the original returned (..., rise, sidu) into outputs
            # ordered (..., sidu, rise), which showed images[3] under "Sidu" and
            # images[4] under "Rise".  The mapping below follows the variable
            # names instead (sidu = images[4], rise = images[3]) -- confirm
            # against the return order of load_image_and_saliency.
            return (
                gr.Image(images[0], elem_classes="main-image", visible=True),
                gr.Image(images[1], elem_classes="main-image", visible=True),
                gr.Image(images[2], elem_classes="main-image", visible=True),
                gr.Image(images[4], elem_classes="main-image", visible=True),
                gr.Image(images[3], elem_classes="main-image", visible=True),
            )

        def update_state(state):
            """Advance to the next class and return the new class index."""
            count = unwrap(state)
            print(f"Updating state: {count}")
            # Return the plain value; the State component stores it as-is.
            return count + 1

        def update_img_label(state):
            """Return the markdown label naming the current target class."""
            return f" Target image: **{class_names[unwrap(state)]}**"

        def update_random_img_id(img_ids, user_state):
            """Draw an image id for the current class and append it to the history."""
            random_img_id = get_random_image_id(unwrap(user_state), data_dir)
            return img_ids + [random_img_id]

        def update_buttons():
            """Show *Continue* and hide *Submit* (state between two classes)."""
            return (
                gr.Button("Continue", visible=True),
                gr.Button("Submit", visible=False),
            )

        def show_view(state):
            """Hide *Continue*; show *Finish* on the last class, *Submit* otherwise."""
            is_last = unwrap(state) == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def hide_view():
            """Hide the question, the five images and the four dropdowns.

            Only visibility is changed; values are refreshed by the handlers
            that run when *Continue* is clicked, so no stale build-time images
            or class names are pushed back to the browser.
            """
            return [gr.update(visible=False) for _ in range(10)]

        def update_dropdowns():
            """Reset the four ranking dropdowns to '-' and show them."""
            return tuple(
                gr.Dropdown(choices=options, value=options[0], label=label, visible=True)
                for label in ("grad-cam", "lime", "sidu", "rise")
            )

        def update_questions(state):
            """Return the visible ranking prompt for the current class."""
            return gr.Markdown(question_text(unwrap(state)), visible=True)

        def redirect():
            """No-op; the redirect itself is done by the ``js`` of the event."""
            pass

        def save_results(answers, img_ids, start_time):
            """Write this session's answers to a JSON file and push it to the Hub.

            Raises:
                ValueError: if ``HUGGINGFACE_TOKEN`` is not set.
            """
            api_token = os.getenv("HUGGINGFACE_TOKEN")
            if not api_token:
                raise ValueError(
                    "Hugging Face API token not found. "
                    "Please set the HUGGINGFACE_TOKEN environment variable."
                )

            json_dataset_dir = Path("json_dataset")
            json_dataset_dir.mkdir(parents=True, exist_ok=True)
            json_dataset_path = json_dataset_dir / f"train-{uuid4()}.json"

            # One scheduler for the whole process: each CommitScheduler starts
            # its own background thread, so creating one per save leaks threads.
            if "scheduler" not in scheduler_cache:
                scheduler_cache["scheduler"] = CommitScheduler(
                    repo_id=f"results_{config['dataset']['name']}_{config['results']['exp1_dir']}",
                    repo_type="dataset",
                    folder_path=json_dataset_dir,
                    path_in_repo="data",
                    token=api_token,
                )
            scheduler = scheduler_cache["scheduler"]

            record = {
                "user_id": time.time(),
                "answers": dict(enumerate(answers)),
                "duration": time.time() - unwrap(start_time),
                "img_ids": [unwrap(img_id) for img_id in img_ids],
                "datetime": datetime.now().isoformat(),
            }

            # Hold the scheduler lock only while writing, so a background
            # commit never uploads a half-written file.
            with scheduler.lock:
                with json_dataset_path.open("a", encoding="utf-8") as f:
                    json.dump(record, f)
                    f.write("\n")
            # Pushed after the lock is released: push_to_hub takes the same
            # (non-reentrant) lock itself.
            scheduler.push_to_hub()

        def check_answer(dropdown1, dropdown2, dropdown3, dropdown4):
            """Raise ``gr.Error`` unless all four ranks are set and distinct."""
            ranks = [dropdown1, dropdown2, dropdown3, dropdown4]
            if '-' in ranks:
                raise gr.Error('Please select a value for each saliency method')
            if len(set(ranks)) < 4:
                print(set(ranks))
                raise gr.Error('Please select different values for each saliency method')

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Append the current ranking to the list of answers."""
            return answers + [[dropdown1, dropdown2, dropdown3, dropdown4]]

        view_outputs = [question] + saliency_outputs + dropdowns

        submit_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            add_answer,
            inputs=dropdowns + [answers],
            outputs=answers,
        ).then(
            update_random_img_id,
            inputs=[img_ids, user_state],
            outputs=img_ids,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_images,
            inputs=user_state,
            outputs=example_imgs,
        ).then(
            update_buttons,
            outputs=[continue_button, submit_button],
        ).then(
            hide_view,
            outputs=view_outputs,
        )

        continue_button.click(
            show_view,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_saliencies,
            inputs=[user_state, img_ids],
            outputs=saliency_outputs,
        ).then(
            update_questions,
            inputs=user_state,
            outputs=question,
        ).then(
            update_dropdowns,
            outputs=dropdowns,
        )

        # The last ranking is validated like every other one; if validation
        # fails nothing is saved and the participant stays on the page.
        finish_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            add_answer,
            inputs=dropdowns + [answers],
            outputs=answers,
        ).then(
            save_results,
            inputs=[answers, img_ids, start_time],
        ).then(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        # Start the per-session timer when the page loads, so the saved
        # duration is measured from the session start, not from server start.
        demo.load(reset_start_time, outputs=start_time)

    demo.launch(root_path='/')
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build and launch the app only when run directly.
    main()