Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import subprocess | |
| import uuid | |
| from PIL import Image | |
| import gradio as gr | |
# Working directories used by the app, relative to the current working directory.
UPLOAD_DIR = "./sessions"  # per-request session folders; not created at import time
RESULTS_DIR = "./results"  # where test.py outputs are searched for
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # model weights
SAMPLE_DIR = "./sample_images"  # example images shown in the UI

# Ensure the directories needed at startup exist (in this order).
for _startup_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_startup_dir, exist_ok=True)
from huggingface_hub import hf_hub_download
from shutil import copyfile

# Hugging Face Hub location of the pretrained generator weights.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# test.py is run with --name SingleImageReflectionRemoval --epoch 310; presumably it
# loads <epoch>_net_G.pth from checkpoints/<name>/, which is this path (TODO confirm).
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)

if os.path.exists(expected_model_path):
    # Checkpoint already in place: skip the Hub call so restarts need no network
    # access and do not re-query the Hub on every import.
    model_path = expected_model_path
else:
    # hf_hub_download stores the file inside its own cache layout under cache_dir
    # and returns that path; copy it to the flat location test.py reads from.
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
    copyfile(model_path, expected_model_path)
def generate_session_id():
    """Return a new random session identifier as a canonical UUID4 string."""
    return f"{uuid.uuid4()}"
def randomize_file_name(original_name):
    """Return a random 32-hex-char file name that keeps *original_name*'s extension.

    Only the last extension is kept (``os.path.splitext`` semantics), so
    ``"a.tar.gz"`` yields ``"<hex>.gz"`` and a name with no extension yields
    just ``"<hex>"``.
    """
    _stem, suffix = os.path.splitext(original_name)
    return uuid.uuid4().hex + suffix
def clear_session_files(session_id):
    """Recursively delete ``UPLOAD_DIR/<session_id>`` if that path exists."""
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
def _build_test_command(dataroot, preprocess_type=None):
    """Return the ``test.py`` argv for one inference run over *dataroot*.

    When *preprocess_type* is ``None`` the ``--preprocess`` flag is omitted,
    leaving test.py to use its own default.
    """
    cmd = [
        "python", "test.py",
        "--dataroot", dataroot,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    if preprocess_type is not None:
        cmd += ["--preprocess", preprocess_type]
    return cmd


def _run_inference(dataroot, preprocess_type, max_attempts=3):
    """Run test.py up to *max_attempts* times; return True on the first success.

    The first attempt uses *preprocess_type*; after a failure, later attempts
    drop ``--preprocess`` and fall back to test.py's default.
    """
    cmd = _build_test_command(dataroot, preprocess_type)
    for _ in range(max_attempts):
        try:
            subprocess.run(cmd, check=True)
            return True
        except subprocess.CalledProcessError:
            cmd = _build_test_command(dataroot)
    return False


def _collect_result(input_filename, original_path, resize_to_input):
    """Load the generated image for *input_filename* and delete its result files.

    Walks RESULTS_DIR for ``<input_filename>*_fake.png`` (the output) and
    ``<input_filename>*_real.png`` (the echoed input) and removes both. The
    first fake image found is returned fully decoded, optionally resized to the
    size of the image at *original_path*. Returns None if nothing matched.
    """
    output_image = None
    for root, _, files in os.walk(RESULTS_DIR):
        for name in files:
            if not name.startswith(input_filename):
                continue
            path = os.path.join(root, name)
            if name.endswith("_fake.png"):
                if output_image is None:
                    # copy() forces a full decode, so the image stays valid after
                    # the file is removed and the handle is closed by `with`.
                    with Image.open(path) as generated:
                        output_image = generated.copy()
                    if resize_to_input:
                        with Image.open(original_path) as source:
                            output_image = output_image.resize(source.size)
                os.remove(path)
            elif name.endswith("_real.png"):
                os.remove(path)
    return output_image


def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Remove reflections from the image at *input_image* using test.py.

    Args:
        input_image: Filesystem path of the uploaded image (Gradio
            ``type="filepath"``), or None/empty when nothing was uploaded.
        preprocess_type: Value forwarded to test.py's ``--preprocess``. Must be
            one of "resize_and_crop", "crop", "scale_width",
            "scale_width_and_crop", "none".

    Returns:
        A PIL Image with the result (resized back to the input size unless
        *preprocess_type* is "crop" or "none"), or a human-readable error string.

    NOTE(review): the error strings are returned to a ``gr.Image`` output;
    raising ``gr.Error`` would display them properly — confirm before changing.
    """
    if preprocess_type not in ("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"):
        return "Invalid preprocessing type selected. Please choose a valid option."
    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)

    # Validate before creating anything on disk, so this early return leaks nothing.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    try:
        # A random name keeps concurrent sessions' outputs in RESULTS_DIR apart.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        shutil.copy(input_image, os.path.join(upload_dir, randomized_name))

        if not _run_inference(upload_dir, preprocess_type):
            return "No results found. Please try again with a different image."

        output_image = _collect_result(
            os.path.splitext(randomized_name)[0],
            input_image,
            resize_to_input=preprocess_type not in ("crop", "none"),
        )
    finally:
        # Always remove the session folder, including on failure paths.
        clear_session_files(session_id)

    if output_image is not None:
        return output_image
    return "No results found."
def use_sample_image(sample_image_name):
    """Return the path of *sample_image_name* inside SAMPLE_DIR.

    If that path does not exist, the string "Sample image not found." is
    returned instead.
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
# Example images for the UI: regular files in SAMPLE_DIR with a .jpg/.jpeg/.png
# extension, matched case-insensitively (so "IMG.JPG" is included) and sorted so
# the example gallery order is stable across runs.
sample_images = sorted(
    name
    for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
    and os.path.isfile(os.path.join(SAMPLE_DIR, name))
)

# Choices offered in the UI; each is forwarded to test.py's --preprocess flag.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
]
def _predict(input_image, preprocess_type):
    """Gradio callback: run reflection removal, treating an empty dropdown as "resize_and_crop"."""
    return reflection_removal(input_image, preprocess_type or "resize_and_crop")


# One example row per sample image, always with the default preprocessing.
_example_rows = [[os.path.join(SAMPLE_DIR, img), "resize_and_crop"] for img in sample_images]

iface = gr.Interface(
    fn=_predict,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
        gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
    ],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=_example_rows,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    iface.launch()