| import datetime |
| from functools import partial |
| import json |
| from pathlib import Path |
| import random |
| import gradio as gr |
| import os |
| import firebase_admin |
| from firebase_admin import db, credentials |
|
|
|
|
| |
| |
| |
|
|
|
|
# Layout of the candidate grid shown below the corrupted sample:
# 7 tiles per row, 2 rows. generate_new_experiment() fills half of the tiles
# from each algorithm (integer division by 2), so the product should be even.
NUMBER_OF_IMAGES_PER_ROW, NUMBER_OF_ROWS = 7, 2
|
|
|
|
| |
| |
| |
|
|
|
|
| |
# Deployment configuration, read from environment variables at import time.
# A missing variable raises KeyError before the app starts.
#   FirebaseSecret - JSON text of the credentials handed to credentials.Certificate
#   FirebaseURL    - value used as the Realtime Database "databaseURL"
#   Dataset        - sub-directory of ./images/ that experiments are drawn from
FIREBASE_API_KEY, FIREBASE_URL, DATASET = (
    os.environ[env_name] for env_name in ("FirebaseSecret", "FirebaseURL", "Dataset")
)

# Initialise the default Firebase app; save() pushes every record under the
# "data" node referenced below.
firebase_creds = credentials.Certificate(json.loads(FIREBASE_API_KEY))
firebase_app = firebase_admin.initialize_app(
    firebase_creds,
    {"databaseURL": FIREBASE_URL},
)
firebase_data_ref = db.reference("data")
|
|
|
|
| |
| |
| |
|
|
|
|
class Experiment(dict):
    """One comparison task shown to a user, stored as a plain ``dict``.

    Keys (in insertion order):
        dataset: name of the dataset directory under ``./images``.
        corruption: name of the corruption directory that was sampled.
        image_id: name of the per-image directory that was sampled.
        corrupted: ``{"name": <path>}`` of the corrupted input image.
        options: list of ``{"name": <path>, "algo": <label>}`` candidates.
        selected_image: index into ``options`` chosen by the user, or None.
    """

    def __init__(self, dataset, corruption, image_id, corrupted, options, selected_image=None):
        fields = {
            "dataset": dataset,
            "corruption": corruption,
            "image_id": image_id,
            "corrupted": corrupted,
            "options": options,
            "selected_image": selected_image,
        }
        super().__init__(fields)
| |
def experiment_to_dict(experiment, skip=False):
    """Flatten an experiment into the flat record that is pushed to Firebase.

    Args:
        experiment: mapping with ``dataset``, ``corruption``, ``image_id``,
            ``corrupted`` (a dict with ``name``), ``options`` (a list of dicts
            with ``name`` and ``algo``) and, unless *skip* is set, an integer
            ``selected_image`` indexing into ``options``.
        skip: true when the user expressed no preference; both selection
            fields are then stored as the literal string ``"None"``.

    Returns:
        A new dict with keys ``dataset``, ``corruption``, ``image_number``,
        ``corrupted_filename``, ``options`` (file names only),
        ``selected_image`` and ``selected_algo``.
    """
    record = {
        "dataset": experiment["dataset"],
        "corruption": experiment["corruption"],
        "image_number": experiment["image_id"],
        "corrupted_filename": experiment["corrupted"]["name"],
        "options": [option["name"] for option in experiment["options"]],
    }

    if skip:
        # Chained assignment runs left to right, keeping the key order.
        record["selected_image"] = record["selected_algo"] = "None"
    else:
        picked = experiment["options"][experiment["selected_image"]]
        record["selected_image"] = picked["name"]
        record["selected_algo"] = picked["algo"]

    return record
| |
def _choose_sorted(candidates, description):
    """Return a random element of *candidates*, raising clearly if there is none.

    The candidates are sorted first because glob results come back in
    filesystem-dependent order; sorting makes a seeded ``random`` reproducible.
    """
    pool = sorted(candidates)
    if not pool:
        raise FileNotFoundError(f"No {description} found")
    return random.choice(pool)


def generate_new_experiment() -> Experiment:
    """Sample a random comparison task from ``./images/<DATASET>``.

    Directory layout implied by the globs below:
    ``./images/<DATASET>/<*>/<corruption>/<image_id>/``. Each image directory
    contains at least one ``*corrupted*`` file plus ``sde/`` and ``ode/``
    sub-directories of restored samples.

    Half of the grid is filled from ``sde/`` (tagged ``"SDEdit"``) and half
    from ``ode/`` (tagged ``"ODEdit"``). The combined list is shuffled so the
    algorithm cannot be guessed from a tile's position.
    NOTE(review): if ``NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS`` is odd, one
    grid tile gets no option.

    Returns:
        A new ``Experiment`` with ``selected_image`` left as None.

    Raises:
        FileNotFoundError: if no wanted corruption directory, image directory
            or corrupted file exists.
        ValueError: if ``sde/`` or ``ode/`` holds fewer files than needed
            (raised by ``random.sample``).
    """
    wanted_corruptions = {
        "spatter",
        "impulse_noise",
        "speckle_noise",
        "gaussian_noise",
        "pixelate",
        "jpeg_compression",
        "elastic_transform",
    }
    imgs_to_sample = (NUMBER_OF_IMAGES_PER_ROW * NUMBER_OF_ROWS) // 2
    dataset_dir = Path(f"./images/{DATASET}")

    corruption = _choose_sorted(
        (d for d in dataset_dir.glob("*/*") if d.name in wanted_corruptions and d.is_dir()),
        f"wanted corruption directory under {dataset_dir}",
    )
    # Only directories can hold the per-image files. Picking a stray file here
    # would make every glob below empty, so filter as for the corruption pick.
    image_id = _choose_sorted(
        (d for d in corruption.glob("*") if d.is_dir()),
        f"image directory under {corruption}",
    )

    corrupted_image = {
        "name": str(_choose_sorted(image_id.glob("*corrupted*"), f"corrupted image in {image_id}"))
    }
    sdedit_images = [
        {"name": str(img), "algo": "SDEdit"}
        for img in random.sample(sorted((image_id / "sde").glob("*")), imgs_to_sample)
    ]
    odedit_images = [
        {"name": str(img), "algo": "ODEdit"}
        for img in random.sample(sorted((image_id / "ode").glob("*")), imgs_to_sample)
    ]
    total_images = sdedit_images + odedit_images
    random.shuffle(total_images)

    return Experiment(
        DATASET,
        corruption.name,
        image_id.name,
        corrupted_image,
        total_images,
    )
|
|
def save(experiment, corrupted_component, *img_components, mode):
    """Record the current experiment in Firebase and advance to a new one.

    Bound to the two buttons through ``partial(save, mode=...)``:

    * ``mode="save"`` stores the tile the user selected.
    * ``mode="skip"`` stores a "no preference" entry.

    If ``mode="save"`` is requested with nothing selected, a warning is shown
    and the inputs are echoed back unchanged, so the page stays as it is.

    Returns:
        ``[experiment, corrupted_component, *img_components]``: either the
        unchanged inputs, or the fresh values produced by ``next()``.
    """
    if mode == "save":
        if experiment is None or experiment["selected_image"] is None:
            gr.Warning("You must select an image before submitting")
            return [experiment, corrupted_component, *img_components]
    elif mode == "skip":
        experiment["selected_image"] = None

    dict_to_save = experiment_to_dict(experiment, skip=(mode == "skip"))
    # Local wall-clock time, second resolution.
    dict_to_save["timestamp"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    firebase_data_ref.push(dict_to_save)

    # Echo the stored record to stdout, framed by separator lines.
    banner = "====================="
    print(banner, dict_to_save, banner, sep="\n")

    gr.Info("Your choice has been saved to Firebase")
    return next()
|
|
|
|
| |
| |
| |
|
|
|
|
def next():
    """Start a fresh experiment and build the components that display it.

    NOTE(review): this name shadows the builtin ``next``. Renaming it means
    updating ``save`` and the ``demo.load`` call as well.

    Returns:
        ``[experiment, corrupted_component, *img_components]``: the new
        experiment state, the corrupted-image component, and one unselected
        image component per option (label = option index).
    """
    experiment = generate_new_experiment()
    image_kwargs = {
        "show_label": False,
        "show_download_button": False,
        "show_share_button": False,
        "interactive": False,
    }

    option_tiles = [
        gr.Image(value=option["name"], label=f"{index}", elem_id="unsel", **image_kwargs)
        for index, option in enumerate(experiment["options"])
    ]
    corrupted_tile = gr.Image(
        value=experiment["corrupted"]["name"],
        label="corr",
        elem_id="corrupted",
        **image_kwargs,
    )

    return [experiment, corrupted_tile, *option_tiles]
|
|
def on_select(evt: gr.SelectData, experiment, *img_components):
    """Highlight the clicked tile and remember it as the current selection.

    Every tile's label is its index in ``experiment["options"]``, so the
    clicked tile is identified by parsing ``evt.target.label``. All tiles are
    rebuilt with ``elem_id="unsel"`` except the clicked one, which gets
    ``elem_id="sel"`` (styled with a highlight border in ``css``).

    Returns:
        ``[experiment, *img_components]`` with ``selected_image`` updated.
    """
    chosen = int(evt.target.label)

    def tile(position, elem_id):
        # Rebuild the image for option *position* with the given styling id.
        return gr.Image(
            value=experiment["options"][position]["name"],
            label=f"{position}",
            elem_id=elem_id,
            show_label=False,
            show_download_button=False,
            show_share_button=False,
            interactive=False,
        )

    tiles = [tile(position, "unsel") for position, _ in enumerate(experiment["options"])]
    tiles[chosen] = tile(chosen, "sel")

    experiment["selected_image"] = chosen
    return [experiment, *tiles]
|
|
# Custom styles, hooked to components through their elem_id.
#   unsel/sel  - candidate tiles; "sel" adds a highlight border (set by on_select)
#   corrupted  - the corrupted sample shown in the header row
#   padded     - horizontal padding for the header Row/Column
# Drag prevention: the original rules used "draggable: false", which is not a
# CSS property (draggable is an HTML attribute), so browsers ignored it. The
# img rule below uses CSS that actually applies. NOTE(review):
# -webkit-user-drag is honoured by WebKit/Blink only; Firefox has no CSS
# equivalent.
# NOTE(review): #reducedHeight is not referenced by any component in this file.
css = """
#unsel {border: solid 5px transparent !important; border-radius: 15px !important}
#sel {border: solid 5px #00c0ff !important; border-radius: 15px !important}
#corrupted {margin-left: 5%; margin-right: 5%; padding: 0 !important}
#unsel img, #sel img, #corrupted img {-webkit-user-drag: none; user-select: none}
#reducedHeight {height: 10px !important}
#padded {padding-left: 2%; padding-right: 2%}
"""
|
|
with gr.Blocks(title="Unsupervised Image Editing", css=css) as demo:
    # Per-session experiment state. It is seeded here and then replaced by
    # next() when the page loads (demo.load at the end of this block).
    experiment = gr.State(generate_new_experiment())

    # Header: corrupted sample on the left, instructions and buttons on the
    # right. NOTE(review): elem_id "padded" is used on both the Row and the
    # Column, which produces duplicate HTML ids; confirm this is intended.
    with gr.Row(elem_id="padded"):
        corrupted_component = gr.Image(
            label="corr",
            elem_id="corrupted",
            show_label=False,
            show_download_button=False,
            show_share_button=False,
            interactive=False,
        )
        with gr.Column(scale=3, elem_id="padded"):
            gr.Markdown("<div style='width: 100%'><h1 style='text-align: center; display: inline-block; width: 100%'>The sample on the left is a corrupted image</h1></div>")
            gr.Markdown("<div style='width: 100%'><h3 style='text-align: center; display: inline-block; width: 100%'>Below are decorrupted versions sampled from various models. Click on the picture you like best.<br/>⚠️Do not pay attention to the background. Consider first fidelity, then quality⚠️</h3></div>")
            btn_skip = gr.Button("I have no preference")
            btn_submit = gr.Button("Submit preference")

    # Candidate grid. Each tile's label is its flat index, which on_select
    # parses to know which option was clicked.
    img_components = []
    for row in range(NUMBER_OF_ROWS):
        with gr.Row():
            img_components.extend(
                gr.Image(
                    label=f"{row * NUMBER_OF_IMAGES_PER_ROW + col}",
                    elem_id="unsel",
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                    interactive=False,
                )
                for col in range(NUMBER_OF_IMAGES_PER_ROW)
            )

    # Both buttons read and refresh the same components; only the mode differs.
    session_io = [experiment, corrupted_component, *img_components]
    for button, button_mode in ((btn_skip, "skip"), (btn_submit, "save")):
        button.click(partial(save, mode=button_mode), session_io, session_io)

    selection_io = [experiment, *img_components]
    for tile in img_components:
        tile.select(on_select, selection_io, selection_io, show_progress="hidden")

    demo.load(next, None, session_io)

demo.launch()