Spaces:
Runtime error
Runtime error
| """Interface for labeling concepts in images. | |
| """ | |
| from typing import Optional | |
| import random | |
| import gradio as gr | |
| from src import global_variables | |
| from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME | |
def filter_sample(sample, concepts, username, sample_type):
    """Decide whether ``sample`` belongs in the current image filter.

    Args:
        sample: Metadata record of one image. It is indexed by concept name
            and may carry a ``"votes"`` mapping keyed by username.
        concepts: Concepts whose value must be truthy on ``sample`` for it to
            be kept (a tie value of ``None`` counts as absent). ``None`` or an
            empty list means "no concept constraint".
        username: User whose votes decide the labelled/unlabelled status.
        sample_type: ``"labelled"`` keeps samples on which ``username`` has a
            recorded vote for every concept in ``CONCEPTS``; ``"unlabelled"``
            keeps all other samples.

    Returns:
        ``True`` if the sample passes both the concept filter and the
        sample-type filter, ``False`` otherwise.

    Raises:
        ValueError: If ``sample_type`` is neither ``"labelled"`` nor
            ``"unlabelled"``.
    """
    # Validate first so a bad sample type is reported for every sample,
    # not only for those that happen to pass the concept filter.
    if sample_type not in ("labelled", "unlabelled"):
        raise ValueError(f"Invalid sample type: {sample_type}")

    if not all(sample[c] for c in concepts or ()):
        return False

    # A missing or null "votes" entry (or a null per-user entry) simply
    # means the user has not voted on this sample yet.
    user_votes = (sample.get("votes") or {}).get(username)
    is_labelled = user_votes is not None and all(c in user_votes for c in CONCEPTS)

    return is_labelled if sample_type == "labelled" else not is_labelled
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile
):
    """Pick a random image from ``split`` that matches the current filter.

    The per-split candidate index lists live in the ``filtered_indices``
    session state. They are rebuilt only when ``concepts`` or ``sample_type``
    differ from the values they were last built for (``selected_concepts`` /
    ``selected_sample_type``).

    Args:
        split: Dataset split to draw from (a key of
            ``global_variables.all_metadata``).
        concepts: Concepts every candidate must have (see ``filter_sample``).
        sample_type: ``"labelled"`` or ``"unlabelled"``.
        filtered_indices: Cached candidate indices per split; updated in place.
        selected_concepts: Concepts the cache was built for.
        selected_sample_type: Sample type the cache was built for.
        profile: OAuth profile of the logged-in user.

    Returns:
        A 10-tuple in this order:
        - the image path;
        - the concepts with a truthy vote from this user;
        - the ``"split:index"`` image key;
        - the sample class;
        - a dict of the sample's concept values;
        - the concepts missing from the user's earlier vote;
        - the tie concepts (value ``None``);
        - the updated ``filtered_indices``, ``selected_concepts`` and
          ``selected_sample_type`` states.
        When no sample matches, the first seven entries are ``None``.
    """
    username = profile.username
    # An empty multiselect dropdown may arrive as None; treat it as "no
    # concept constraint" so the comparison and filtering below work.
    concepts = concepts or []

    # Rebuild the cached candidate lists only when the filter changed.
    if concepts != selected_concepts or sample_type != selected_sample_type:
        for key, values in global_variables.all_metadata.items():
            filtered_indices[key] = [
                i for i, value in enumerate(values)
                if filter_sample(value, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type

    # Check for "nothing to show" explicitly instead of catching IndexError
    # around the whole body, which would also hide unrelated indexing bugs.
    candidates = filtered_indices.get(split, [])
    if not candidates:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type

    sample_idx = random.choice(candidates)
    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No recorded vote by this user on this image.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        # Concepts absent from the user's earlier vote on this image.
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]

    tie_concepts = [c for c in CONCEPTS if sample[c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile
):
    """Record the user's vote for the displayed image, then load the next one.

    Args:
        voted_concepts: Concepts ticked in the "Voted Concepts" checkbox group.
        current_image: ``"split:index"`` key of the displayed image, or
            ``None`` when nothing is displayed.
        split, concepts, sample_type: Current filter widget values, forwarded
            to ``get_next_image``.
        filtered_indices, selected_concepts, selected_sample_type: Session
            state, forwarded to ``get_next_image``.
        profile: OAuth profile of the logged-in user.

    Returns:
        The 10-tuple produced by ``get_next_image``. When no image is
        selected, the seven display values are ``None`` and the state
        values are returned unchanged.
    """
    voter = profile.username

    if current_image is None:
        gr.Warning("No image selected.")
        state = (filtered_indices, selected_concepts, selected_sample_type)
        return (None,) * 7 + state

    global_variables.update_votes(voter, current_image, voted_concepts)
    gr.Info("Submit success")

    next_args = (
        split,
        concepts,
        sample_type,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
        profile,
    )
    return get_next_image(*next_args)
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Save the logged-in user's work and show a confirmation toast.

    The actual saving is delegated to ``global_variables.save_current_work``
    with the user's name.
    """
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")
with gr.Blocks() as interface:
    # Left column: filter controls, action buttons and voting widgets.
    # Right column: the image being labelled.
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                with gr.Row():
                    split = gr.Radio(
                        label="Split",
                        choices=["train", "test"],
                        value="train",
                    )
                    sample_type = gr.Radio(
                        label="Sample Type",
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                    )
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Local Submit",
                )
            with gr.Row():
                save_button = gr.Button(
                    value="Save",
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # Per-session state.
    # "split:index" key of the displayed image (set by get_next_image).
    current_image = gr.State(None)
    # Candidate indices per split. get_next_image rebuilds them whenever the
    # filter differs from selected_concepts / selected_sample_type.
    filtered_indices = gr.State({
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })
    selected_concepts = gr.State([])
    # None never equals a Radio value, so the first get_next_image call
    # always builds the filter.
    selected_sample_type = gr.State(None)

    # Order must match the tuples returned by get_next_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]

    next_button.click(
        get_next_image,
        inputs=[split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )
    # save_current_work returns nothing, so it is not wired to any output.
    # Routing its None into the image component would clear the displayed
    # image while current_image still points at it, and a later submit would
    # then vote on an image the user can no longer see.
    save_button.click(
        save_current_work,
    )