|
|
import gradio as gr |
|
|
import yaml |
|
|
import random |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from huggingface_hub import CommitScheduler, HfApi |
|
|
|
|
|
from src.utils import load_words, load_example_images, load_csv_concepts, generate_random_ids |
|
|
from src.style import css |
|
|
from src.user import UserID |
|
|
|
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from uuid import uuid4 |
|
|
import json |
|
|
from huggingface_hub import CommitScheduler |
|
|
|
|
|
def main():
    """Build and launch the "Saliency evaluation - experiment 2" Gradio app.

    Reads ``config/config.yaml`` to find the active dataset (its class names and
    on-disk directory), then lays out the UI: a target-class label, a hidden
    concept-set checkbox group, a row of 16 example images for the current
    class, and Continue / Submit / Finish buttons.
    """
    # Context manager so the config file handle is closed as soon as it is parsed.
    with open("config/config.yaml") as config_file:
        config = yaml.safe_load(config_file)

    class_names = config['dataset'][config['dataset']['name']]['class_names']
    data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        title = gr.Markdown("# Saliency evaluation - experiment 2")

        # Per-session state: current class index, collected answers, and the
        # randomised concept groups / display order shown for each class index.
        user_state = gr.State(0)
        answers = gr.State([])
        random_answer_order = gr.State({})
        # NOTE(review): time.time() is evaluated once, when main() builds the UI,
        # so every session starts from the same (app build) timestamp rather than
        # the moment the participant opened the page — confirm this is intended.
        start_time = gr.State(time.time())

        target_img_label = gr.Markdown(f"Target class: **{class_names[user_state.value]}**")
        question = gr.Markdown()

        # presumably a pandas DataFrame of concepts per class (indexed with .iloc later) — verify in src.utils
        concepts = load_csv_concepts(data_dir)

        # Placeholder choices; hidden until the participant presses "Continue".
        concept_checkboxes = gr.CheckboxGroup(
            ['c1, c2, c3', 'c4, c5, c6', 'c7, c8, c9'],
            label="Choose the concept set that better describes the target class",
            visible=False,
        )

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            # At build time user_state is the gr.State component, so read its initial value.
            # assumes load_example_images returns at least 16 indexable images — TODO confirm
            images = load_example_images(user_state.value, data_dir)
            # Individual names are kept because the event handlers refer to img1..img16;
            # the generator creates them in the same order as before.
            (img1, img2, img3, img4, img5, img6, img7, img8,
             img9, img10, img11, img12, img13, img14, img15, img16) = (
                gr.Image(images[i]) for i in range(16)
            )

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)
|
|
|
|
|
def update_label(concept_checkboxes, user_state):
    """Return the 16 example-image components for the current class.

    Args:
        concept_checkboxes: Current checkbox selection. Unused here; it is
            accepted because the event wiring passes it as an input.
        user_state: Current class index, either a plain int or an object
            exposing ``.value`` (e.g. a stored gr.State).

    Returns:
        A 16-tuple with one entry per image slot: new ``gr.Image`` components
        while the index is a valid class, otherwise no-op ``gr.update()``
        values so the displayed images are left unchanged.
    """
    n_slots = 16
    count = user_state if isinstance(user_state, int) else user_state.value
    n_classes = config['dataset'][config['dataset']['name']]['n_classes']

    if count < n_classes:
        images = load_example_images(count, data_dir)
        return tuple(gr.Image(images[i]) for i in range(n_slots))

    # Past the last class: nothing new to show, so leave every slot untouched.
    # (Returning img1..img16 here would raise UnboundLocalError, because those
    # names are only bound inside the branch above.)
    return tuple(gr.update() for _ in range(n_slots))
|
|
|
|
|
def update_state(state):
    """Advance the per-session class index by one.

    Args:
        state: The current index, either a plain int or an object exposing
            ``.value`` (such as a gr.State stored by an earlier version of
            this handler).

    Returns:
        The incremented index as a plain int, so the session state never holds
        a component object that downstream handlers must unwrap.
    """
    count = state if isinstance(state, int) else state.value
    return count + 1
|
|
|
|
|
def update_img_label(state):
    """Build the Markdown heading naming the class at the given index.

    ``state`` may be a plain int or an object exposing ``.value``.
    """
    index = state.value if not isinstance(state, int) else state
    label = class_names[index]
    return f"### Target class: {label}"
|
|
|
|
|
def update_buttons():
    """Re-show "Continue" and hide "Submit" once an answer has been submitted.

    Returns:
        ``(continue_button, submit_button)`` component updates, in that order.
    """
    hidden_submit = gr.Button("Submit", visible=False)
    shown_continue = gr.Button("Continue", visible=True)
    return shown_continue, hidden_submit
|
|
|
|
|
def update_continue_button(state):
    """Hide "Continue" and reveal either "Submit" or, on the last class, "Finish".

    Args:
        state: Current class index (int or object exposing ``.value``).

    Returns:
        ``(continue_button, submit_button, finish_button)`` component updates.
    """
    current = state.value if not isinstance(state, int) else state
    last_index = config['dataset'][config['dataset']['name']]['n_classes'] - 1
    on_last_class = current == last_index

    finish = gr.Button("Finish", visible=on_last_class)
    submit = gr.Button("Submit", visible=not on_last_class)
    hidden_continue = gr.Button("Continue", visible=False)
    return hidden_continue, submit, finish
|
|
|
|
|
|
|
|
def update_checkbox(user_state, random_answer_order):
    """Show three concept sets for the current class in a random order.

    Each set is built from three concept columns of the current class's row
    in ``concepts``. The option value of each displayed set is its group index,
    so the stored answer identifies which group was picked regardless of the
    display order.

    Args:
        user_state: Current class index (int or object exposing ``.value``).
        random_answer_order: Dict mutated in place; the entry for this class
            index records the column-index groups ("random_ids") and the
            display permutation ("random_order").

    Returns:
        ``(random_answer_order, concept_checkboxes)`` where the second item is a
        visible ``gr.CheckboxGroup`` with no selection.
    """
    idx = user_state.value if not isinstance(user_state, int) else user_state

    # assumes concepts is a pandas DataFrame (uses .iloc / .keys()) — TODO confirm
    row = concepts.iloc[idx]
    columns = concepts.keys()

    # Only entries [0..2][0..2] are read: three groups of three column indices,
    # coerced to int so they can index ``columns``.
    raw_ids = generate_random_ids()
    id_groups = [[int(raw_ids[g][k]) for k in range(3)] for g in range(3)]

    # Random display order of the three groups.
    order = np.random.permutation(3)

    print('random_ids:', id_groups)
    print('random_order:', order)

    random_answer_order[idx] = {
        "random_ids": id_groups,
        "random_order": order,
    }

    def describe(group_pos):
        # Comma-separated concept names of one group of columns.
        return ", ".join(f"{row[columns[col]]}" for col in id_groups[group_pos])

    choices = [(describe(order[slot]), int(order[slot])) for slot in range(3)]

    concept_checkboxes = gr.CheckboxGroup(
        choices=choices,
        label=f"Choose the concept set that better describes the class {class_names[idx]}",
        value=None,
        visible=True,
    )
    return random_answer_order, concept_checkboxes
|
|
|
|
|
def hide_checkbox():
    """Hide and clear the concept checkbox group between classes.

    The choices are placeholders (the group is invisible). They use the same
    ``c1..c9`` labels as the group's initial placeholders; the first entry
    previously read ``'c10, c2, c3'``, an inconsistent typo.

    Returns:
        An invisible ``gr.CheckboxGroup`` update with no selection.
    """
    return gr.CheckboxGroup(
        choices=['c1, c2, c3', 'c4, c5, c6', 'c7, c8, c9'],
        label="Choose the concept set that better describes the target class",
        value=None,
        visible=False,
    )
|
|
|
|
|
def redirect():
    """Intentionally do nothing on the Python side.

    The navigation away from the app is performed by the ``js`` snippet that
    the "Finish" event chain attaches to this step.
    """
    return None
|
|
|
|
|
def save_results(answers, random_answer_order):
    """Write one participant's results to a new JSON file and push it to the Hub.

    Args:
        answers: Per-class checkbox selections, in the order they were given.
        random_answer_order: Dict keyed by class index. Each value has
            "random_ids" (groups of column indices) and "random_order" (the
            display permutation) shown for that class.

    Raises:
        ValueError: If the ``HFTOKEN`` environment variable is not set.
    """
    # Check credentials first so nothing else runs without them.
    api_token = os.getenv("HFTOKEN")
    if not api_token:
        raise ValueError("Hugging Face API token not found. Please set the HFTOKEN environment variable.")

    # Walk entries in class-index order and keep the real index as the key, so a
    # missing entry cannot crash the save or shift later entries onto the wrong class.
    shown = sorted(random_answer_order.items())
    random_ids = {i: [list(group) for group in entry['random_ids']] for i, entry in shown}
    random_order = {i: [int(pos) for pos in entry['random_order']] for i, entry in shown}

    json_dataset_dir = Path("json_dataset")
    json_dataset_dir.mkdir(parents=True, exist_ok=True)
    # One uniquely named file per submission.
    json_dataset_path = json_dataset_dir / f"train-{uuid4()}.json"

    scheduler = CommitScheduler(
        repo_id=f"results_{config['dataset']['name']}_{config['results']['exp2_dir']}",
        repo_type="dataset",
        folder_path=json_dataset_dir,
        path_in_repo="data",
        token=api_token,
    )

    # NOTE(review): start_time.value is the timestamp taken when the UI was built,
    # not when this participant started — duration may be inflated; confirm.
    duration = time.time() - start_time.value

    info_to_push = {
        # NOTE(review): a wall-clock timestamp is used as the participant id;
        # two participants finishing at the same instant would collide.
        "user_id": time.time(),
        "answer": dict(enumerate(answers)),
        "random_ids": random_ids,
        "random_order": random_order,
        "duration": duration,
    }
    print('INFO TO PUSH:', info_to_push)

    with scheduler.lock:
        with json_dataset_path.open("a") as f:
            json.dump({
                "user_id": info_to_push["user_id"],
                "answers": info_to_push["answer"],
                "random_ids": info_to_push["random_ids"],
                "random_order": info_to_push["random_order"],
                "duration": info_to_push["duration"],
                "datetime": datetime.now().isoformat(),
            }, f)
            f.write("\n")

    try:
        scheduler.push_to_hub()
    finally:
        # A new scheduler (with its own periodic background commit loop, per the
        # huggingface_hub docs) is created on every call; stop it once this
        # submission has been pushed so schedulers do not pile up per participant.
        scheduler.stop()
|
|
|
|
|
def check_answer(concept_checkboxes):
    """Ensure exactly one concept set is selected.

    Args:
        concept_checkboxes: The selected choice values; ``None`` is treated
            as an empty selection instead of failing with TypeError.

    Raises:
        gr.Error: If no concept set, or more than one, is selected.
    """
    selected = concept_checkboxes or []
    if len(selected) > 1:
        raise gr.Error("Please select only one concept set")
    if not selected:
        raise gr.Error("Please select a concept set")
|
|
|
|
|
def add_answer(concept_checkboxes, answers):
    """Record the latest selection and hand the answers list back.

    The list passed in is extended in place, and that same object is returned
    so the caller can store it as the updated session state.
    """
    collected = answers
    collected.append(concept_checkboxes)
    print('ANSWERS:', collected, concept_checkboxes)
    return collected
|
|
|
|
|
# Submit: validate the selection, advance to the next class, record the answer,
# then reset the UI (label, buttons, checkbox, images) for the new class.
# `outputs` are lists, not sets: each handler returns a tuple, and list order is
# what maps each returned value to its component (sets require dict returns).
submit_button.click(
    check_answer,
    inputs=concept_checkboxes,
).success(
    update_state,
    inputs=user_state,
    outputs=user_state,
).then(
    add_answer,
    inputs=[concept_checkboxes, answers],
    outputs=answers,
).then(
    update_img_label,
    inputs=user_state,
    outputs=target_img_label,
).then(
    update_buttons,
    outputs=[continue_button, submit_button],
).then(
    hide_checkbox,
    outputs=concept_checkboxes,
).then(
    update_label,
    inputs=[concept_checkboxes, user_state],
    outputs=[img1, img2, img3, img4, img5, img6, img7, img8,
             img9, img10, img11, img12, img13, img14, img15, img16],
)

# Continue: swap the buttons, then show the shuffled concept sets for this class.
continue_button.click(
    update_continue_button,
    inputs=user_state,
    outputs=[continue_button, submit_button, finish_button],
).then(
    update_checkbox,
    inputs=[user_state, random_answer_order],
    outputs=[random_answer_order, concept_checkboxes],
)

# Finish: validate, record the last answer, save, and only then leave the page.
finish_button.click(
    check_answer, inputs=concept_checkboxes,
).success(
    update_state, inputs=user_state, outputs=user_state,
).then(
    add_answer, inputs=[concept_checkboxes, answers], outputs=answers,
).then(
    save_results, inputs=[answers, random_answer_order],
).success(
    # .success (not .then): redirect only if save_results did not raise;
    # otherwise the participant would be sent away and their results lost.
    # NOTE(review): a retry after a failed save re-runs add_answer, duplicating
    # the last answer — consider guarding against that.
    redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
)

# NOTE(review): no handler is attached to this load event; confirm it is still needed.
demo.load()
demo.launch()
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when the file is run as a script, not when imported.
    main()
|
|
|