Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import shutil | |
| import os | |
| import subprocess | |
| import uuid | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from shutil import copyfile | |
# Working directories used by the app.
RESULTS_DIR = "./results"
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"

# Pretrained generator weights hosted on the Hugging Face Hub.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir  # loop variable only; keep the module namespace unchanged

# Download the weights into a Hub cache rooted at CHECKPOINTS_DIR.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# Also place a copy at the flat path CHECKPOINTS_DIR/MODEL_FILE.
# NOTE(review): presumably this is where test.py loads weights from
# (checkpoints/<--name>/<--epoch>_net_G.pth) — confirm.
# An existing file at that path is reused as-is.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if not os.path.exists(expected_model_path):
    copyfile(model_path, expected_model_path)
def _find_result(input_filename, results_dir=None):
    """Return the newest ``<input_filename>_fake.png`` under *results_dir*, or None.

    The file name must match exactly rather than by prefix. This way an input
    named ``photo`` cannot pick up ``photo2_fake.png`` left by another request.
    *results_dir* defaults to ``RESULTS_DIR``.
    """
    if results_dir is None:
        results_dir = RESULTS_DIR
    target = f"{input_filename}_fake.png"
    matches = [
        os.path.join(root, target)
        for root, _, files in os.walk(results_dir)
        if target in files
    ]
    # RESULTS_DIR is shared and never cleaned here, so earlier runs may have
    # left files with the same name; prefer the most recently written one.
    return max(matches, key=os.path.getmtime, default=None)


def reflection_removal(input_image, preprocess_type="resize_and_crop", session_id=None):
    """Run the pix2pix reflection-removal model on one uploaded image.

    Steps:
    1. Copy the image into ``./sessions/<session_id>/uploads``.
    2. Run ``test.py`` on that folder as a subprocess.
    3. Look up the generated ``<name>_fake.png`` under ``RESULTS_DIR``.

    The session folder is always removed afterwards, even if ``test.py`` fails.

    Args:
        input_image: Path to the uploaded image (the UI uses ``type="filepath"``),
            or a falsy value when nothing was uploaded.
        preprocess_type: Forwarded to ``test.py`` as ``--preprocess``.
        session_id: Names the ``./sessions/<session_id>`` working folder.

    Returns:
        A ``PIL.Image.Image`` with the result. Otherwise a message string when
        the session id is missing, the input image is missing, or no result
        file was found.
        NOTE(review): the UI output is a ``gr.Image``, so these strings are
        probably not shown as text; consider raising ``gr.Error`` instead.

    Raises:
        subprocess.CalledProcessError: If ``test.py`` exits with a non-zero status.
    """
    import sys  # run test.py with the same interpreter/venv as this app

    if not session_id:
        return "Session ID missing. Please try again."
    # Validate before creating any folders so an empty submit leaves nothing behind.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    upload_dir = os.path.join("./sessions", session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    try:
        file_path = os.path.join(upload_dir, os.path.basename(input_image))
        shutil.copy(input_image, file_path)
        input_filename = os.path.splitext(os.path.basename(file_path))[0]

        cmd = [
            sys.executable, "test.py",
            "--dataroot", upload_dir,
            "--name", "SingleImageReflectionRemoval",
            "--model", "test", "--netG", "unet_256",
            "--direction", "AtoB", "--dataset_mode", "single",
            "--norm", "batch", "--epoch", "310",
            "--num_test", "1",
            "--gpu_ids", "-1",  # presumably CPU-only inference
            "--preprocess", preprocess_type,
        ]
        subprocess.run(cmd, check=True)

        result_path = _find_result(input_filename)
        if result_path is None:
            return "No results found."
        # Load the pixels now so the returned image does not hold the file open.
        with Image.open(result_path) as result:
            return result.copy()
    finally:
        # Runs on success, on "no results", and when test.py raises.
        clear_session_files(session_id)
def clear_session_files(session_id):
    """Delete ``./sessions/<session_id>`` and everything below it, if it exists.

    A missing folder is ignored. A confirmation line is printed only when a
    folder was actually removed.
    """
    target_dir = os.path.join("./sessions", session_id)
    if not os.path.exists(target_dir):
        return
    shutil.rmtree(target_dir)
    print(f"Session {session_id} files cleared.")
def clear_action(session_id=None):
    """Remove a session's files and return a status message.

    If *session_id* is falsy, nothing is touched and a "nothing to clear"
    message is returned. Otherwise the session folder is removed (if present)
    and "Upload cleared!" is returned.
    """
    if not session_id:
        return "No session found to clear."
    clear_session_files(session_id)
    return "Upload cleared!"
# Example images shipped with the app: every .jpg/.jpeg/.png file directly
# inside SAMPLE_DIR.
# - Extensions are matched case-insensitively so e.g. "IMG_01.JPG" is included.
# - Names are sorted because os.listdir returns them in arbitrary order.
sample_images = sorted(
    name
    for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Choices for the UI's "Preprocessing Type" dropdown; the selected value is
# forwarded to test.py as --preprocess.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
]
def session_interface():
    """Build the Gradio interface for the reflection-removal demo.

    Every request gets a fresh session id, so each call uses its own
    ``./sessions/<id>`` upload folder.
    """

    def run_request(img, prep):
        # One id per call. Previously a single uuid was created when the
        # interface was built and shared by every user, so concurrent requests
        # wrote into (and then deleted) the same upload folder.
        return reflection_removal(img, prep, str(uuid.uuid4()))

    return gr.Interface(
        fn=run_request,
        inputs=[
            gr.Image(type="filepath", label="Upload Image (JPG/PNG)", interactive=True),
            gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
        ],
        outputs=gr.Image(label="Results after Reflection Removal"),
        # Reuse the module-level sample listing instead of re-scanning the folder.
        examples=[
            [os.path.join(SAMPLE_DIR, img), "resize_and_crop"]
            for img in sample_images
        ],
        title="Reflection Remover with Pix2Pix",
        description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
        allow_flagging="never",
        live=False,
    )
# Root folder for the per-session upload directories.
os.makedirs("./sessions", exist_ok=True)

if __name__ == "__main__":
    app = session_interface()
    app.launch()