Spaces:
Paused
Paused
| import csv | |
| import json | |
| import os | |
| import pickle | |
| import random | |
| import string | |
| import sys | |
| import time | |
| from glob import glob | |
| import datasets | |
| import gdown | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torchvision | |
| from huggingface_hub import HfApi, login, snapshot_download | |
| from PIL import Image | |
# --- Hub auth and module-level configuration -------------------------------
# Token is injected by the Space's secrets; login() authenticates huggingface_hub.
session_token = os.environ.get("SessionToken")
login(token=session_token)
# Raise the CSV field-size cap to the platform maximum (some fields are very large).
csv.field_size_limit(sys.maxsize)
# Clock-based seed so every session draws a different weighted sample.
np.random.seed(int(time.time()))
# k-NN cache: presumably maps an ImageNet-Hard query id -> nearest training sample ids
# (used by label_dist_of_nns below) — TODO confirm pickle schema.
with open('./imagenet_hard_nearest_indices.pkl', 'rb') as f:
    knn_results = pickle.load(f)
# WNID -> human-readable ImageNet label.
with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)
# Class id -> label mapping; loaded but not referenced in this file's visible code.
with open('id_to_label.json', 'r') as f:
    id_to_labels = json.load(f)
# ex2.txt: one candidate "bad" sample per line, formatted like "<int id>.<ext>";
# keep only the integer ids, skipping blank lines.
# NOTE(review): file handle is never closed — consider a `with` block.
bad_items = open('./ex2.txt', 'r').read().split('\n')
bad_items = [x.split('.')[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != '']
# --- One-time asset download ------------------------------------------------
# cached_download skips the fetch when ./data.zip already exists with a matching md5.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)
# Extract only if either expected output folder is missing.
# NOTE: "imagenet_traning_samples" spelling matches the archive's folder name.
if not os.path.exists("./imagenet_traning_samples") or not os.path.exists("./knn_cache_for_imagenet_hard"):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )
# The dataset being reviewed; items are indexed by integer position below.
imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def update_snapshot():
    """Download all per-decision JSON files from the review repo and tabulate them.

    Each uploaded file (see update_app) holds one decision dict with keys
    id/user_id/time/decision.

    Returns:
        tuple[pd.DataFrame, int]: (df, total_size) where df has one row per
        decision file with columns ["id", "user_id", "time", "decision"], and
        total_size is the number of items in the imagenet_hard split.
    """
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data", allow_patterns="*.json", repo_type="dataset"
    )
    total_size = len(imagenet_hard)
    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in glob(f"{output_dir}/*.json"):
        # one JSON object per file; project it onto the fixed column order
        with open(file) as f:
            data = json.load(f)
        rows.append([data[x] for x in columns])
    # (removed a dead `df = pd.DataFrame()` that was immediately overwritten)
    df = pd.DataFrame(rows, columns=columns)
    return df, total_size
# df = update_snapshot()
# Size of the review queue generated per session (see generate_dataset).
NUMBER_OF_IMAGES = 1000
# Sample `sample_size` ids from the bad_items pool, weighted by inverse review count.
def sample_ids(df, total_size, sample_size):
    """Draw `sample_size` distinct ids, favoring ids with fewer existing reviews.

    Args:
        df: DataFrame from update_snapshot(); its "id" column counts prior reviews.
        total_size: accepted for interface compatibility; not used here.
        sample_size: number of ids to draw (without replacement).

    Returns:
        np.ndarray of sampled ids drawn from the module-level `bad_items` pool.
    """
    id_counts = df['id'].value_counts().to_dict()
    all_ids = bad_items
    # An id reviewed c times gets weight 1/(c+1); unseen ids get the max weight 1.
    # (fixed: original shadowed the builtin `id` while pre-filling missing counts)
    inverse_weights = [1 / (id_counts.get(item_id, 0) + 1) for item_id in all_ids]
    # Hoisted out of the comprehension — the original recomputed sum() per element (O(n^2)).
    total_weight = sum(inverse_weights)
    normalized_weights = [w / total_weight for w in inverse_weights]
    sampled_ids = np.random.choice(all_ids, size=sample_size, replace=False, p=normalized_weights)
    return sampled_ids
def generate_dataset():
    """Build the session's review queue.

    Samples NUMBER_OF_IMAGES ids (weighted toward less-reviewed ones) and
    materializes each as a dict with the fields update_app/load_sample expect.

    Returns:
        list[dict]: each with keys "id", "image", "correct_label" (the
        english_label list) and "original_id" (same id, coerced to int).
    """
    df, total_size = update_snapshot()
    random_indices = sample_ids(df, total_size, NUMBER_OF_IMAGES)
    data = []
    for idx in random_indices:
        # Fetch each dataset item once (the original indexed imagenet_hard three
        # times per sample and also built an unused list of numeric labels).
        sample = imagenet_hard[int(idx)]
        data.append(
            {
                "id": idx,
                "image": sample["image"],
                "correct_label": sample["english_label"],
                "original_id": int(idx),
            }
        )
    return data
def string_to_image(text):
    """Draw `text` centered on a blank white canvas and return the figure.

    Underscores become spaces, the text is lowercased, and comma-separated
    labels are placed on separate lines.
    """
    caption = text.replace('_', ' ').lower().replace(', ', '\n')
    # An all-ones RGB array renders as a white rectangle.
    canvas = np.ones((220, 75, 3))
    fig, ax = plt.subplots(figsize=(6, 2.25))
    ax.imshow(canvas, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, caption, fontsize=18, ha='center', va='center')
    # Strip every axis decoration so only the caption remains visible.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for side in ax.spines.values():
        side.set_visible(False)
    return fig
def label_dist_of_nns(qid):
    """Return the label distribution over the 15 nearest training neighbors of `qid`.

    Looks up the query's k-NN ids in the module-level `knn_results` cache,
    maps each neighbor to its human-readable label, and returns
    {label: fraction of the 15 neighbors}, ordered by descending frequency.
    """
    from collections import Counter  # local import: keeps the file-level import block untouched

    with open('./trainingset_filenames.json', 'r') as f:
        trainingset_filenames = json.load(f)
    nns = knn_results[qid][:15]
    labels = [wnid_to_label[trainingset_filenames[f"{x}"]] for x in nns]
    # Counter replaces the original O(n^2) labels.count(x) over set(labels),
    # and most_common() gives the same count-descending order deterministically.
    counts = Counter(labels)
    # Convert raw counts to fractions of the neighborhood.
    return {label: count / len(labels) for label, count in counts.most_common()}
from glob import glob  # NOTE(review): duplicate — glob is already imported at the top of the file
# Map training-sample id (the integer prefix of the filename) -> extracted JPEG path.
all_samples = glob('./imagenet_traning_samples/*.JPEG')
qid_to_sample = {int(x.split('/')[-1].split('.')[0].split('_')[0]): x for x in all_samples}
def get_training_samples(qid):
    """Return one training-sample image path per class label of query `qid`."""
    class_ids = imagenet_hard[int(qid)]['label']
    return [qid_to_sample[class_id] for class_id in class_ids]
# Folders produced by extracting data.zip above.
knn_cache_path = "knn_cache_for_imagenet_hard"
# "traning" spelling intentionally matches the archive's folder name.
imagenet_training_samples_path = "imagenet_traning_samples"
def load_sample(data, current_index):
    """Return (query image, correct-label list) for the item at `current_index`.

    Args:
        data: list of sample dicts built by generate_dataset().
        current_index: position within `data`.
    """
    # (removed an unused local that read sample["id"])
    sample = data[current_index]
    return sample["image"], sample["correct_label"]
def update_app(decision, data, current_index, history, username):
    """Record the user's decision for the current image and advance to the next one.

    Wired to all three buttons; `decision` is the clicked button's label text.
    Returns (query image, label figure, new index, history, data, training
    sample images) matching the outputs list of each .click() handler.
    """
    # First click of a session: index is still -1, so build the review queue.
    if current_index == -1:
        data = generate_dataset()
    if current_index>=0 and current_index < NUMBER_OF_IMAGES-1:
        time_stamp = int(time.time())
        image_id = data[current_index]["id"]
        # NOTE(review): "dicision" is a typo kept throughout this function; the
        # uploaded JSON keys themselves are spelled correctly.
        dicision_dict = {
            "id": int(image_id),
            "user_id": username,
            "time": time_stamp,
            "decision": decision,
        }
        # Persist the decision as a small JSON file and push it to the dataset
        # repo; update_snapshot() later aggregates these files.
        temp_filename = f"results_{username}_{time_stamp}.json"
        with open(temp_filename, "w") as f:
            json.dump(dicision_dict, f)
        api = HfApi()
        api.upload_file(
            path_or_fileobj=temp_filename,
            path_in_repo=temp_filename,
            repo_id="taesiri/imagenet_hard_review_data",
            repo_type="dataset",
        )
        os.remove(temp_filename)
    elif current_index == NUMBER_OF_IMAGES-1:
        # Queue exhausted: clear the UI without recording or advancing.
        # NOTE(review): the last item's decision is therefore never uploaded —
        # confirm whether that is intended.
        return None, None, current_index, history, data, None
    current_index += 1
    qimage, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]
    training_samples_image = get_training_samples(image_id)
    training_samples_image = [Image.open(x).convert('RGB') for x in training_samples_image]
    # labels is a list of label strings; join them for rendering as one caption.
    labels = ", ".join(labels)
    label_plot = string_to_image(labels)
    return qimage, label_plot, current_index, history, data, training_samples_image
# Custom CSS: let the query image and both galleries size themselves to content.
newcss = '''
#query_image{
    height: auto !important;
}
#nn_gallery {
    height: auto !important;
}
#sample_gallery {
    height: auto !important;
}
'''
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(css=newcss) as demo:
    # Per-session state: the sampled queue, current position, and (unused) history.
    data_gr = gr.State({})
    current_index = gr.State(-1)  # -1 signals update_app to build the queue
    history = gr.State({})
    gr.Markdown("# Cleaning ImageNet-Hard!")
    # Default username: random 5-char suffix so anonymous sessions don't collide.
    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )
    username = gr.Textbox(label="Username", value=f"user-{random_str}")
    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")  # NOTE(review): "myabe" typo in the variable name only
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                # NOTE(review): `type='fig'` on gr.Plot and `type="pil"` on
                # gr.Gallery are version-dependent kwargs — confirm against the
                # pinned gradio version.
                label_plot = gr.Plot(label='Is this a correct label for this image?', type='fig')
                training_samples = gr.Gallery(type="pil", label="Training samples" , elem_id="sample_gallery")
    # with gr.Column():
    #     gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)")
    #     nn_labels = gr.Label(label="NN-Labels")
    #     neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery")
    # All three buttons share update_app; the clicked button component is passed
    # as the `decision` input, so its label text becomes the recorded decision.
    accept_btn.click(
        update_app,
        inputs=[accept_btn, data_gr, current_index, history, username],
        outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
    )
    myabe_btn.click(
        update_app,
        inputs=[myabe_btn, data_gr, current_index, history, username],
        outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
    )
    reject_btn.click(
        update_app,
        inputs=[reject_btn, data_gr, current_index, history, username],
        outputs=[query_image, label_plot, current_index, history, data_gr, training_samples]
    )
demo.launch()