# NOTE: "Spaces: Sleeping" is the Hugging Face Spaces page banner captured with this file; it is not part of the program.
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.style import css
from src.user import UserID
from src.utils import load_words, load_image_and_saliency, load_example_images, load_csv_concepts
def main():
    """Build and launch the Gradio UI for saliency-evaluation experiment 1.

    For each class index the participant first sees 16 example images, then
    (after pressing *Continue*) a target image with four saliency maps
    (grad-cam, lime, sidu, rise) to rank from 1 to 4.  *Submit* validates and
    records the ranking and advances to the next class.  *Finish* validates
    and records the last ranking, writes all answers to a JSON-lines file that
    is pushed to a Hugging Face dataset repository, and redirects the browser
    to the end page.

    Raises:
        OSError: if ``config/config.yaml`` cannot be opened.
        KeyError: if the configuration lacks one of the dataset keys read here.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    options = ['-', '1', '2', '3', '4']  # '-' is the "nothing selected" placeholder
    n_examples = 16

    # Positions inside the sequence returned by load_image_and_saliency, listed
    # in display order: target, grad-cam, lime, sidu, rise.
    # NOTE(review): the original built the hidden "sidu"/"rise" components from
    # images[4]/images[3], but the values actually shown after "Continue" were
    # images[3] under "sidu" and images[4] under "rise" (set outputs are ordered
    # by component creation).  The displayed mapping is preserved here; confirm
    # it against load_image_and_saliency.
    saliency_indices = (0, 1, 2, 3, 4)

    def as_count(state):
        """Return the integer question index held in *state*.

        Accepts a plain int (what event handlers receive) or an object with a
        ``value`` attribute such as the ``gr.State`` component itself.
        """
        return state if isinstance(state, int) else state.value

    def question_text(count):
        """Return the ranking instruction for class index *count*."""
        return f"### Sort the following saliency maps according to which of them better explains the class {class_names[count]}."

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # Main App Components
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)
        answers = gr.State([])
        # Placeholder only: the real per-session start time is set on page load.
        start_time = gr.State(time.time())

        gr.Markdown("### Image examples")
        with gr.Row():
            examples = load_example_images(0, data_dir)
            example_imgs = [gr.Image(examples[i]) for i in range(n_examples)]

        question = gr.Markdown(question_text(0), visible=False)

        with gr.Row():
            target_img_label = gr.Markdown(f"Target image: **{class_names[0]}**")
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            initial = load_image_and_saliency(0, data_dir)
            first, *others = saliency_indices
            # Only the target image carries the extra "delay" CSS class.
            saliency_imgs = [
                gr.Image(initial[first], elem_classes="main-image delay", visible=False)
            ]
            saliency_imgs += [
                gr.Image(initial[i], elem_classes="main-image", visible=False)
                for i in others
            ]

        with gr.Row():
            dropdowns = [
                gr.Dropdown(choices=options, label=method, visible=False)
                for method in methods
            ]

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        def update_images(user_state):
            """Return updates showing the example images of the current class."""
            count = as_count(user_state)
            if count >= n_classes:
                # Nothing left to show: leave every example image untouched.
                return tuple(gr.update() for _ in example_imgs)
            images = load_example_images(count, data_dir)
            return tuple(
                gr.Image(images[i], visible=True) for i in range(n_examples)
            )

        def update_saliencies(user_state):
            """Return updates revealing the target image and its saliency maps."""
            count = as_count(user_state)
            if count >= n_classes:
                return tuple(gr.update() for _ in saliency_imgs)
            images = load_image_and_saliency(count, data_dir)
            return tuple(
                gr.Image(images[i], elem_classes="main-image", visible=True)
                for i in saliency_indices
            )

        def update_state(state):
            """Advance to the next class; the state holds a plain int."""
            return as_count(state) + 1

        def update_img_label(state):
            """Return the target-image caption for the current class."""
            return f" Target image: **{class_names[as_count(state)]}**"

        def update_buttons():
            """After a submit: show *Continue*, hide *Submit*."""
            return (
                gr.Button("Continue", visible=True),
                gr.Button("Submit", visible=False),
            )

        def show_view(state):
            """After *Continue*: show *Finish* on the last class, else *Submit*."""
            is_last = as_count(state) == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def hide_view():
            """Hide the question, the saliency row and the dropdowns.

            Only visibility is changed; values are refreshed on the next
            *Continue* click, so no stale build-time content is re-sent.
            """
            return tuple(gr.update(visible=False) for _ in ranking_view)

        def update_dropdowns():
            """Reveal the four dropdowns, each reset to the '-' placeholder."""
            return tuple(
                gr.Dropdown(choices=options, value=options[0], label=method, visible=True)
                for method in methods
            )

        def update_questions(state):
            """Return the visible ranking instruction for the current class."""
            return gr.Markdown(question_text(as_count(state)), visible=True)

        def redirect():
            """No server-side work; the redirect is done by the attached JS."""

        def start_timer():
            """Return the session start time (seconds since the epoch)."""
            return time.time()

        def save_results(answers, started_at):
            """Write the session's answers to JSON and push them to the Hub.

            Raises:
                ValueError: if the HUGGINGFACE_TOKEN environment variable is unset.
            """
            api_token = os.getenv("HUGGINGFACE_TOKEN")
            if not api_token:
                raise ValueError("Hugging Face API token not found. Please set the HUGGINGFACE_TOKEN environment variable.")

            dataset_dir = Path("json_dataset")
            dataset_dir.mkdir(parents=True, exist_ok=True)
            result_path = dataset_dir / f"train-{uuid4()}.json"

            scheduler = CommitScheduler(
                repo_id=f"results_{config['dataset']['name']}_{config['results']['exp1_dir']}",
                repo_type="dataset",
                folder_path=dataset_dir,
                path_in_repo="data",
                token=api_token,
            )
            record = {
                "user_id": time.time(),
                "answers": dict(enumerate(answers)),
                "duration": time.time() - started_at,
                "datetime": datetime.now().isoformat(),
            }
            try:
                # Hold the scheduler lock while writing so a background commit
                # never uploads a half-written file.
                with scheduler.lock:
                    with result_path.open("a") as f:
                        json.dump(record, f)
                        f.write("\n")
                # Must run outside the lock: push_to_hub takes the lock itself.
                scheduler.push_to_hub()
            finally:
                # One scheduler is created per call; stop its background
                # thread so finished sessions do not accumulate threads.
                scheduler.stop()

        def check_answer(dropdown1, dropdown2, dropdown3, dropdown4):
            """Raise ``gr.Error`` unless the four ranks are all set and distinct."""
            ranks = [dropdown1, dropdown2, dropdown3, dropdown4]
            if '-' in ranks:
                raise gr.Error('Please select a value for each saliency method')
            if len(set(ranks)) < len(ranks):
                raise gr.Error('Please select different values for each saliency method')

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Return *answers* extended with the current ranking (no mutation)."""
            return answers + [[dropdown1, dropdown2, dropdown3, dropdown4]]

        # Components hidden after each submit, in the order hide_view returns them.
        ranking_view = [question, *saliency_imgs, *dropdowns]

        # Outputs are ordered lists so that each returned value maps to its
        # component explicitly rather than through set ordering.
        submit_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_images,
            inputs=user_state,
            outputs=example_imgs,
        ).then(
            update_buttons,
            outputs=[continue_button, submit_button],
        ).then(
            hide_view,
            outputs=ranking_view,
        )

        continue_button.click(
            show_view,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_saliencies,
            inputs=user_state,
            outputs=saliency_imgs,
        ).then(
            update_questions,
            inputs=user_state,
            outputs=question,
        ).then(
            update_dropdowns,
            outputs=dropdowns,
        )

        # Validate the last ranking too, and redirect only once saving succeeded.
        finish_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).success(
            save_results,
            inputs=[answers, start_time],
        ).success(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        # Start the duration clock when the page is loaded.
        demo.load(start_timer, outputs=start_time)

    demo.launch(root_path='/')
# Script entry point: build and launch the app only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()