"""Gradio app for the saliency-evaluation experiment 2.

For every class of the configured dataset the user sees 16 example images,
clicks *Continue* to reveal three randomly composed concept sets, selects the
one that better describes the class and clicks *Submit* (or *Finish* on the
last class). On *Finish* the answers, the random concept ids/orders that were
shown and the session duration are appended as one JSON line to a file that is
pushed to a Hugging Face dataset repository.
"""
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import gradio as gr
import numpy as np
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.style import css
from src.user import UserID
from src.utils import generate_random_ids, load_csv_concepts, load_example_images, load_words

CONFIG_PATH = "config/config.yaml"
N_EXAMPLE_IMAGES = 16  # example images shown per class
N_CONCEPT_SETS = 3  # concept sets offered per question
CONCEPTS_PER_SET = 3  # concepts listed in each set
PLACEHOLDER_CHOICES = ['c1, c2, c3', 'c4, c5, c6', 'c7, c8, c9']
PLACEHOLDER_LABEL = "Choose the concept set that better describes the target class"
END_PAGE_URL = "https://marcoparola.github.io/saliency-evaluation-app/end"


def _as_count(state):
    """Return the integer counter held in a user-state value.

    Event handlers normally receive the raw ``int``; a ``gr.State`` object
    (e.g. the component itself while building the layout) is unwrapped via
    its ``.value``.
    """
    return state if isinstance(state, int) else state.value


def main():
    """Build the experiment-2 Gradio UI, wire its events and launch it."""
    with open(CONFIG_PATH, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)
    # Row `count` of this table supplies the concepts shown for class `count`.
    concepts = load_csv_concepts(data_dir)

    # Results go to one JSON-lines file per running app, uploaded by a single
    # CommitScheduler. The scheduler is created lazily (so a missing token only
    # fails at submission time) and reused, because every CommitScheduler runs
    # its own background commit thread.
    json_dataset_dir = Path("json_dataset")
    json_dataset_path = json_dataset_dir / f"train-{uuid4()}.json"
    scheduler_cache = {}

    def get_scheduler():
        """Return the shared CommitScheduler, creating it on first use.

        Raises:
            ValueError: if the ``HFTOKEN`` environment variable is not set.
        """
        scheduler = scheduler_cache.get("scheduler")
        if scheduler is None:
            api_token = os.getenv("HFTOKEN")
            if not api_token:
                raise ValueError(
                    "Hugging Face API token not found. Please set the HFTOKEN environment variable."
                )
            json_dataset_dir.mkdir(parents=True, exist_ok=True)
            scheduler = CommitScheduler(
                repo_id=f"results_{dataset_name}_{config['results']['exp2_dir']}",
                repo_type="dataset",
                folder_path=json_dataset_dir,
                path_in_repo="data",
                token=api_token,
            )
            scheduler_cache["scheduler"] = scheduler
        return scheduler

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # ------------------------------------------------------------ layout
        gr.Markdown("# Saliency evaluation - experiment 2")
        user_state = gr.State(0)  # index of the class currently shown
        answers = gr.State([])  # one list of selected choice values per class
        random_answer_order = gr.State({})  # class index -> ids/order shown
        start_time = gr.State(time.time())  # reset per session by demo.load below
        target_img_label = gr.Markdown(f"Target class: **{class_names[user_state.value]}**")
        gr.Markdown()
        concept_checkboxes = gr.CheckboxGroup(
            PLACEHOLDER_CHOICES,
            label=PLACEHOLDER_LABEL,
            visible=False,
        )

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            initial_images = load_example_images(_as_count(user_state), data_dir)
            example_images = [gr.Image(initial_images[i]) for i in range(N_EXAMPLE_IMAGES)]

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        # ---------------------------------------------------------- handlers
        def reset_start_time():
            """Return the current time; runs once when a user's page loads."""
            return time.time()

        def update_label(user_state):
            """Return updates for the example images of the current class.

            Past the last class the images are left unchanged.
            """
            count = _as_count(user_state)
            if count >= n_classes:
                return [gr.update() for _ in example_images]
            images = load_example_images(count, data_dir)
            return [gr.Image(images[i]) for i in range(N_EXAMPLE_IMAGES)]

        def update_state(state):
            """Advance the class counter by one."""
            return _as_count(state) + 1

        def update_img_label(state):
            """Return the heading text naming the current target class."""
            return f"### Target class: {class_names[_as_count(state)]}"

        def update_buttons():
            """Show Continue and hide Submit (returned in that order)."""
            return gr.Button("Continue", visible=True), gr.Button("Submit", visible=False)

        def update_continue_button(state):
            """Hide Continue; show Finish on the last class, Submit otherwise."""
            is_last = _as_count(state) == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def update_checkbox(user_state, random_answer_order):
            """Draw random concept sets for the current class and show them.

            The drawn ids and display order are stored under the class index in
            ``random_answer_order`` so they can be saved with the answers. Each
            choice's value is the index of its set in ``random_ids``.
            """
            count = _as_count(user_state)
            row = concepts.iloc[count]
            keys = concepts.keys()
            drawn = generate_random_ids()
            random_ids = [
                [int(drawn[i][j]) for j in range(CONCEPTS_PER_SET)]
                for i in range(N_CONCEPT_SETS)
            ]
            random_order = [int(k) for k in np.random.permutation(N_CONCEPT_SETS)]
            print('random_ids:', random_ids)
            print('random_order:', random_order)
            random_answer_order[count] = {
                "random_ids": random_ids,
                "random_order": random_order,
            }

            def describe(set_index):
                # Comma-separated concept names of one drawn set.
                return ", ".join(f"{row[keys[concept_id]]}" for concept_id in random_ids[set_index])

            concept_checkboxes = gr.CheckboxGroup(
                choices=[(describe(k), k) for k in random_order],
                label=f"Choose the concept set that better describes the class {class_names[count]}",
                value=None,
                visible=True,
            )
            return random_answer_order, concept_checkboxes

        def hide_checkbox():
            """Reset the concept checkboxes to hidden placeholders with no selection."""
            return gr.CheckboxGroup(
                choices=PLACEHOLDER_CHOICES,
                label=PLACEHOLDER_LABEL,
                value=None,
                visible=False,
            )

        def redirect():
            """No-op; the redirect itself is done by the event's ``js``."""

        def save_results(answers, random_answer_order, start_time):
            """Append this session's results as one JSON line and push to the Hub.

            Args:
                answers: selected choice values, one list per answered class.
                random_answer_order: class index -> {"random_ids", "random_order"}
                    as recorded by ``update_checkbox``.
                start_time: session start, in ``time.time()`` seconds.

            Raises:
                ValueError: if the ``HFTOKEN`` environment variable is not set.
            """
            scheduler = get_scheduler()
            shown = [random_answer_order[k] for k in sorted(random_answer_order)]
            record = {
                "user_id": time.time(),
                "answers": dict(enumerate(answers)),
                # Plain ints only: numpy scalars are not JSON-serializable.
                "random_ids": {
                    i: [[int(c) for c in id_set] for id_set in q["random_ids"]]
                    for i, q in enumerate(shown)
                },
                "random_order": {
                    i: [int(k) for k in q["random_order"]] for i, q in enumerate(shown)
                },
                "duration": time.time() - start_time,
                "datetime": datetime.now().isoformat(),
            }
            print('INFO TO PUSH:', record)
            # The lock keeps the scheduler from uploading a half-written line.
            with scheduler.lock:
                with json_dataset_path.open("a") as f:
                    json.dump(record, f)
                    f.write("\n")
            scheduler.push_to_hub()

        def check_answer(concept_checkboxes):
            """Raise a user-visible error unless exactly one concept set is selected."""
            selected = concept_checkboxes or []
            if len(selected) > 1:
                raise gr.Error("Please select only one concept set")
            if not selected:
                raise gr.Error("Please select a concept set")

        def add_answer(concept_checkboxes, answers):
            """Append the current selection to the answers list and return it."""
            answers.append(concept_checkboxes)
            print('ANSWERS:', answers, concept_checkboxes)
            return answers

        # ------------------------------------------------------------ events
        # Outputs are explicit lists so return values map to components by position.
        submit_button.click(
            check_answer, inputs=concept_checkboxes
        ).success(
            update_state, inputs=user_state, outputs=user_state
        ).then(
            add_answer, inputs=[concept_checkboxes, answers], outputs=answers
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_buttons, outputs=[continue_button, submit_button]
        ).then(
            hide_checkbox, outputs=concept_checkboxes
        ).then(
            update_label, inputs=user_state, outputs=example_images
        )

        continue_button.click(
            update_continue_button,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_checkbox,
            inputs=[user_state, random_answer_order],
            outputs=[random_answer_order, concept_checkboxes],
        )

        finish_button.click(
            check_answer, inputs=concept_checkboxes
        ).success(
            update_state, inputs=user_state, outputs=user_state
        ).then(
            add_answer, inputs=[concept_checkboxes, answers], outputs=answers
        ).then(
            save_results, inputs=[answers, random_answer_order, start_time]
        ).then(
            redirect, js=f"window.location = '{END_PAGE_URL}'"
        )

        # The State's initial value is computed once at build time; reset it
        # per page load so `duration` measures this user's session.
        demo.load(reset_start_time, outputs=start_time)

    demo.launch()


if __name__ == "__main__":
    main()