# Hugging Face Spaces status: Sleeping
| import gradio as gr | |
| import yaml | |
| import random | |
| import os | |
| import json | |
| from pathlib import Path | |
| from huggingface_hub import CommitScheduler, HfApi | |
| from src.utils import load_words, load_image_and_saliency, load_example_images | |
| from src.style import css | |
| from src.user import UserID | |
def main():
    """Build and launch the Gradio app for saliency-evaluation experiment 1.

    Reads ``config/config.yaml`` to locate the dataset. The app then shows one
    class at a time with these components:

    - the target image;
    - four saliency maps (Grad-CAM, LIME, SIDU, RISE);
    - one ranking dropdown per method;
    - 16 example images of the same class.

    "Submit" records the current ranking and advances to the next class. On
    the last class, "Finish" records the final ranking, saves all answers and
    redirects the browser to the end page.
    """
    import time  # timestamps each saved submission

    # Use a context manager so the config file handle is closed.
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    # Saliency methods in the order load_image_and_saliency returns them.
    # Index 0 is the target image; indices 1..4 follow this list.
    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    # '-' is the default/placeholder choice; '1'..'4' are the selectable ranks.
    options = ['-', '1', '2', '3', '4']
    n_examples = 16  # example images shown per class
    n_main = 1 + len(methods)  # target image + one saliency map per method

    def _index(state):
        """Return the class index held by ``state``.

        Handlers receive a plain int. The ``.value`` branch also accepts a
        gr.State object.
        """
        return state if isinstance(state, int) else state.value

    def _label_text(idx):
        """Markdown caption naming the target class ``idx``.

        The same format is used at build time and on every update.
        """
        return f"Target image: **{class_names[idx]}**"

    def _is_last(idx):
        """True when ``idx`` is the final class, i.e. "Finish" replaces "Submit"."""
        return idx == n_classes - 1

    def _step_values(idx):
        """Image values for class ``idx``, in component order.

        The order is: target image, the 4 saliency maps, then the example images.
        """
        main_images = load_image_and_saliency(idx, data_dir)
        examples = load_example_images(idx, data_dir)
        return ([main_images[i] for i in range(n_main)]
                + [examples[i] for i in range(n_examples)])

    initial = _step_values(0)

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)  # index of the class currently shown
        answers = gr.State([])  # one [grad-cam, lime, sidu, rise] rank list per class

        with gr.Row():
            target_img_label = gr.Markdown(_label_text(0))
            for method in methods:
                # Renders "Grad-cam", "Lime", "Sidu", "Rise".
                gr.Markdown(method.capitalize())

        with gr.Row():
            target_img = gr.Image(initial[0], elem_classes="main-image")
            saliency_imgs = [gr.Image(img, elem_classes="main-image")
                             for img in initial[1:n_main]]

        with gr.Row():
            # Start at options[0] so untouched dropdowns record the same value
            # they are reset to after each submission.
            dropdowns = [gr.Dropdown(choices=options, value=options[0], label=method)
                         for method in methods]

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            example_imgs = [gr.Image(img) for img in initial[n_main:]]

        # Visibility follows the same rule update_buttons applies after each
        # step, so a single-class dataset starts with "Finish" showing.
        submit_button = gr.Button("Submit", visible=not _is_last(0))
        finish_button = gr.Button("Finish", visible=_is_last(0))

        # Ordered list: the handlers return values in exactly this order.
        # A set would have no order, and Gradio would then expect a dict return.
        image_outputs = [target_img, *saliency_imgs, *example_imgs]

        def update_images(state):
            """Load the images for the class in ``state``.

            Past the last class, every image is left unchanged.
            """
            idx = _index(state)
            if idx >= n_classes:
                return tuple(gr.update() for _ in image_outputs)
            return tuple(_step_values(idx))

        def update_state(state):
            """Advance to the next class.

            Returns the new plain int, which becomes the new session state.
            """
            return _index(state) + 1

        def update_img_label(state):
            """Caption for the class in ``state``; unchanged if out of range."""
            idx = _index(state)
            if idx >= len(class_names):
                return gr.update()
            return _label_text(idx)

        def update_buttons(state):
            """Show "Finish" only on the last class and "Submit" otherwise."""
            last = _is_last(_index(state))
            return gr.update(visible=not last), gr.update(visible=last)

        def update_dropdowns():
            """Reset every ranking dropdown to the placeholder choice."""
            return tuple(gr.update(value=options[0]) for _ in methods)

        def redirect():
            """No-op; the redirect itself is done by the ``js`` snippet on the event."""
            return None

        def save_results(answers_so_far):
            """Persist the collected rankings.

            Appends one JSON line to ``json_dataset/<config['results']['exp1_dir']>``.
            If ``config['results']`` defines ``repo_id``, that file is also
            uploaded to the Hub dataset repo.
            """
            results_cfg = config['results']
            results_name = results_cfg['exp1_dir']
            json_dataset_path = Path("json_dataset") / results_name
            # results_name may contain sub-directories, so create the file's parent.
            json_dataset_path.parent.mkdir(parents=True, exist_ok=True)
            record = {
                "user_id": time.time(),
                "answer": dict(enumerate(answers_so_far)),
            }
            with json_dataset_path.open("a", encoding="utf-8") as results_file:
                results_file.write(json.dumps(record) + "\n")

            # NOTE(review): HfApi has no push_to_hub method, so the original
            # call always failed. The upload below only runs when a target repo
            # is configured; 'repo_id' is an assumed config key — confirm.
            # Each upload overwrites the repo file with this replica's local copy.
            repo_id = results_cfg.get('repo_id')
            if repo_id:
                HfApi().upload_file(
                    path_or_fileobj=json_dataset_path,
                    path_in_repo=results_name,
                    repo_id=repo_id,
                    repo_type="dataset",
                )

        def add_answer(rank_gradcam, rank_lime, rank_sidu, rank_rise, answers_so_far):
            """Return a new answer list with this class's rankings appended."""
            return answers_so_far + [[rank_gradcam, rank_lime, rank_sidu, rank_rise]]

        submit_button.click(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_buttons,
            inputs=user_state,
            outputs=[submit_button, finish_button],
        ).then(
            update_images,
            inputs=user_state,
            outputs=image_outputs,
        ).then(
            update_dropdowns,
            outputs=dropdowns,
        )

        finish_button.click(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            save_results,
            inputs=answers,
        ).then(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

    demo.launch()
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()