"""Gradio app for saliency evaluation, experiment 1.

For each class the participant first sees example images, then a target
image next to four saliency maps, and ranks the maps from 1 to 4. When the
last class has been ranked the answers are written to a JSON file and pushed
to a Hugging Face dataset repository.
"""

import json
import os
import random
import threading
import time
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.style import css
from src.user import UserID
from src.utils import (
    load_csv_concepts,
    load_example_images,
    load_image_and_saliency,
    load_words,
)

# Saliency methods, in the left-to-right order used by the column headings,
# the saliency images and the ranking dropdowns.
METHODS = ("grad-cam", "lime", "sidu", "rise")

# Index of each method's map in the sequence returned by
# load_image_and_saliency (index 0 is the target image).
# NOTE(review): sidu=4 / rise=3 is the mapping the component construction has
# always used; confirm it against load_image_and_saliency.
SALIENCY_INDEX = {"grad-cam": 1, "lime": 2, "sidu": 4, "rise": 3}

# Number of example images shown for every class.
N_EXAMPLES = 16

# Dropdown entry meaning "no rank chosen yet", followed by the valid ranks.
RANK_PLACEHOLDER = "-"
RANKS = ("1", "2", "3", "4")


def main():
    """Build the ranking UI, wire its events and launch the Gradio server."""
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)
    options = [RANK_PLACEHOLDER, *RANKS]

    def question_text(count):
        """Return the ranking prompt for the class at index ``count``."""
        return (
            "### Sort the following saliency maps according to which of them "
            f"better explains the class {class_names[count]}."
        )

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # ------------------------------------------------------------------
        # Components
        # ------------------------------------------------------------------
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)  # index of the class currently evaluated
        answers = gr.State([])  # one [grad-cam, lime, sidu, rise] rank list per class
        # Default only; the real per-session value is set by demo.load below.
        start_time = gr.State(time.time())
        first_class = user_state.value

        gr.Markdown("### Image examples")
        with gr.Row():
            examples = load_example_images(first_class, data_dir)
            example_imgs = [gr.Image(examples[i]) for i in range(N_EXAMPLES)]

        question = gr.Markdown(question_text(first_class), visible=False)

        with gr.Row():
            target_img_label = gr.Markdown(
                f"Target image: **{class_names[first_class]}**"
            )
            for method in METHODS:
                gr.Markdown(method.capitalize())

        with gr.Row():
            images = load_image_and_saliency(first_class, data_dir)
            target_img = gr.Image(
                images[0], elem_classes="main-image delay", visible=False
            )
            saliency_imgs = [
                gr.Image(
                    images[SALIENCY_INDEX[method]],
                    elem_classes="main-image",
                    visible=False,
                )
                for method in METHODS
            ]

        with gr.Row():
            dropdowns = [
                gr.Dropdown(choices=options, label=method, visible=False)
                for method in METHODS
            ]

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        # Explicit, ordered output lists: every handler below returns its
        # values positionally in exactly these orders.
        ranking_view = [question, target_img, *saliency_imgs, *dropdowns]
        saliency_view = [target_img, *saliency_imgs]

        # ------------------------------------------------------------------
        # Handlers
        # ------------------------------------------------------------------
        def start_timer():
            """Return the current time; used to stamp the start of a session."""
            return time.time()

        def update_images(count):
            """Show the example images of class ``count``.

            Returns one update per example image. When ``count`` is past the
            last class the images are left untouched.
            """
            if count >= n_classes:
                return tuple(gr.update() for _ in range(N_EXAMPLES))
            class_examples = load_example_images(count, data_dir)
            return tuple(
                gr.Image(class_examples[i], visible=True)
                for i in range(N_EXAMPLES)
            )

        def update_saliencies(count):
            """Show the target image and saliency maps of class ``count``.

            The returned order matches ``saliency_view``: target image first,
            then one map per entry of ``METHODS``. When ``count`` is past the
            last class the images are left untouched.
            """
            if count >= n_classes:
                return tuple(gr.update() for _ in saliency_view)
            loaded = load_image_and_saliency(count, data_dir)
            shown = [loaded[0]]
            shown.extend(loaded[SALIENCY_INDEX[method]] for method in METHODS)
            return tuple(
                gr.Image(img, elem_classes="main-image", visible=True)
                for img in shown
            )

        def update_state(count):
            """Advance to the next class; the state holds a plain int."""
            return count + 1

        def update_img_label(count):
            """Return the target-image caption for class ``count``."""
            return f" Target image: **{class_names[count]}**"

        def update_buttons():
            """Return to the examples step: show Continue, hide Submit."""
            return (
                gr.Button("Continue", visible=True),
                gr.Button("Submit", visible=False),
            )

        def show_view(count):
            """Enter the ranking step.

            Hides Continue and shows Submit, or Finish when ``count`` is the
            last class. Order: continue, submit, finish.
            """
            is_last = count == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def hide_view():
            """Hide the question, ranking images and dropdowns."""
            return tuple(gr.update(visible=False) for _ in ranking_view)

        def update_dropdowns():
            """Show the dropdowns again, reset to the placeholder entry."""
            return tuple(
                gr.Dropdown(
                    choices=options,
                    value=RANK_PLACEHOLDER,
                    label=method,
                    visible=True,
                )
                for method in METHODS
            )

        def update_questions(count):
            """Show the ranking prompt for class ``count``."""
            return gr.Markdown(question_text(count), visible=True)

        def redirect():
            """No-op; the redirect itself is done by the ``js`` snippet."""

        def check_answer(dropdown1, dropdown2, dropdown3, dropdown4):
            """Validate a ranking.

            Raises:
                gr.Error: if any dropdown holds no valid rank (placeholder or
                    empty), or if two methods share the same rank.
            """
            ranks = [dropdown1, dropdown2, dropdown3, dropdown4]
            if any(rank not in RANKS for rank in ranks):
                raise gr.Error('Please select a value for each saliency method')
            if len(set(ranks)) < len(ranks):
                raise gr.Error(
                    'Please select different values for each saliency method'
                )

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Return ``answers`` extended with the current ranking."""
            return answers + [[dropdown1, dropdown2, dropdown3, dropdown4]]

        # A CommitScheduler starts its own background thread, so exactly one
        # is created (on first use) and shared by all sessions.
        results_dir = Path("json_dataset")
        scheduler_guard = threading.Lock()
        scheduler_cache = {}

        def get_scheduler(api_token):
            """Return the shared CommitScheduler, creating it on first use."""
            with scheduler_guard:
                if "scheduler" not in scheduler_cache:
                    results_dir.mkdir(parents=True, exist_ok=True)
                    scheduler_cache["scheduler"] = CommitScheduler(
                        repo_id=(
                            f"results_{config['dataset']['name']}"
                            f"_{config['results']['exp1_dir']}"
                        ),
                        repo_type="dataset",
                        folder_path=results_dir,
                        path_in_repo="data",
                        token=api_token,
                    )
                return scheduler_cache["scheduler"]

        def save_results(dropdown1, dropdown2, dropdown3, dropdown4,
                         answers, started_at):
            """Write the session's answers to disk and push them to the Hub.

            The final ranking is taken from the dropdowns instead of being
            appended to the ``answers`` state, so clicking Finish again after
            a failure does not record that ranking twice in the state.

            Raises:
                ValueError: if the HUGGINGFACE_TOKEN environment variable is
                    not set.
            """
            api_token = os.getenv("HUGGINGFACE_TOKEN")
            if not api_token:
                raise ValueError(
                    "Hugging Face API token not found. Please set the "
                    "HUGGINGFACE_TOKEN environment variable."
                )

            all_answers = answers + [
                [dropdown1, dropdown2, dropdown3, dropdown4]
            ]
            record = {
                "user_id": time.time(),
                "answers": dict(enumerate(all_answers)),
                "duration": time.time() - started_at,
                "datetime": datetime.now().isoformat(),
            }

            scheduler = get_scheduler(api_token)
            results_path = results_dir / f"train-{uuid4()}.json"
            # Hold the scheduler lock while writing so a background commit
            # does not pick up a half-written file.
            with scheduler.lock:
                with results_path.open("a") as results_file:
                    json.dump(record, results_file)
                    results_file.write("\n")
            # Push after releasing the lock, as in the original call order.
            scheduler.push_to_hub()

        # ------------------------------------------------------------------
        # Event wiring
        # ------------------------------------------------------------------
        submit_button.click(
            check_answer, inputs=dropdowns
        ).success(
            update_state, inputs=user_state, outputs=user_state
        ).then(
            add_answer, inputs=[*dropdowns, answers], outputs=answers
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_images, inputs=user_state, outputs=example_imgs
        ).then(
            update_buttons, outputs=[continue_button, submit_button]
        ).then(
            hide_view, outputs=ranking_view
        )

        continue_button.click(
            show_view,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_saliencies, inputs=user_state, outputs=saliency_view
        ).then(
            update_questions, inputs=user_state, outputs=question
        ).then(
            update_dropdowns, outputs=dropdowns
        )

        # Validate the last ranking too, and only leave the page once the
        # results have actually been saved.
        finish_button.click(
            check_answer, inputs=dropdowns
        ).success(
            save_results, inputs=[*dropdowns, answers, start_time]
        ).success(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        # Stamp the start time per browser session rather than at app build.
        demo.load(start_timer, outputs=start_time)

    demo.launch(root_path='/')


if __name__ == "__main__":
    main()