Spaces:
Sleeping
Sleeping
File size: 5,169 Bytes
2bb76b3 b79726b 2bb76b3 d3aec7f 2bb76b3 b79726b 2bb76b3 b79726b 2bb76b3 9489bae 2bb76b3 9489bae 2bb76b3 b79726b 2bb76b3 0b76722 2bb76b3 a0c04c5 02a2aa5 b79726b a0c04c5 b79726b ec61f43 7ba731c b79726b d3aec7f 02a2aa5 d3aec7f b79726b 26e1c44 d3aec7f b79726b d3aec7f 2bb76b3 9489bae d3aec7f 2bb76b3 d3aec7f 95f6ecc 220be7c 2bb76b3 7ba731c d3aec7f b79726b c9b5e9c 9489bae 02a2aa5 60c5500 f6fd486 4c9b339 d3aec7f f6fd486 b79726b f6fd486 2d422fa 4c9b339 b79726b 2bb76b3 9489bae 2bb76b3 b8a1b3c b79726b 7bd8b18 ec61f43 b79726b ec61f43 b79726b 95b14fb b79726b 2bb76b3 b79726b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
from PIL import Image
# Working directories used by the app (relative to the current working dir).
UPLOAD_DIR = "./sessions"  # parent of per-request upload folders, created on demand
RESULTS_DIR = "./results"  # scanned by reflection_removal for inference outputs
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # model weights
SAMPLE_DIR = "./sample_images"  # example images offered in the UI

# Ensure every directory read or written at startup exists.
for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
from huggingface_hub import hf_hub_download
from shutil import copyfile
# Hugging Face Hub location of the pretrained generator weights.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"
# Fixed location the weights are copied to; presumably where test.py looks for
# them given "--name SingleImageReflectionRemoval --epoch 310" (TODO confirm).
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if os.path.exists(expected_model_path):
    # Weights already in place: skip the Hub round-trip so a restart works
    # without network access. An existing file was never overwritten anyway,
    # so the checkpoint that gets used is unchanged.
    model_path = expected_model_path
else:
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR)
    # hf_hub_download returns a path inside its cache layout; copy the file to
    # the flat location the inference script expects.
    copyfile(model_path, expected_model_path)
def generate_session_id():
    """Return a fresh random session identifier (a UUID4 in canonical string form)."""
    session_uuid = uuid.uuid4()
    return f"{session_uuid}"
def randomize_file_name(original_name):
    """Return a random 32-hex-char file name that keeps original_name's extension.

    Only the last extension is kept (e.g. "a.tar.gz" -> "<hex>.gz"); a name
    without an extension yields just the hex string.
    """
    _, ext = os.path.splitext(original_name)
    return uuid.uuid4().hex + ext
def clear_session_files(session_id):
    """Delete the directory tree UPLOAD_DIR/<session_id>; do nothing if it is absent."""
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
def reflection_removal(input_image, preprocess_type="resize_and_crop"):
if preprocess_type not in ["resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"]:
return "Invalid preprocessing type selected. Please choose a valid option."
print("Preprocessing Type:", preprocess_type)
print("Input Image:", input_image)
session_id = generate_session_id()
session_dir = os.path.join(UPLOAD_DIR, session_id)
upload_dir = os.path.join(session_dir, "uploads")
os.makedirs(upload_dir, exist_ok=True)
if not input_image or not os.path.exists(input_image):
return "No image was provided or file was cleared. Please upload a valid image."
randomized_name = randomize_file_name(os.path.basename(input_image))
file_path = os.path.join(upload_dir, randomized_name)
shutil.copy(input_image, file_path)
input_filename = os.path.splitext(randomized_name)[0]
cmd = [
"python", "test.py",
"--dataroot", upload_dir,
"--name", "SingleImageReflectionRemoval",
"--model", "test", "--netG", "unet_256",
"--direction", "AtoB", "--dataset_mode", "single",
"--norm", "batch", "--epoch", "310",
"--num_test", "1",
"--gpu_ids", "-1",
"--preprocess", preprocess_type
]
attempt = 0
while True:
attempt += 1
try:
subprocess.run(cmd, check=True)
break
except subprocess.CalledProcessError as e:
cmd = [
"python", "test.py",
"--dataroot", upload_dir,
"--name", "SingleImageReflectionRemoval",
"--model", "test", "--netG", "unet_256",
"--direction", "AtoB", "--dataset_mode", "single",
"--norm", "batch", "--epoch", "310",
"--num_test", "1",
"--gpu_ids", "-1",
]
if attempt > 2:
return "No results found. Please try again with a different image."
output_image = None
for root, _, files in os.walk(RESULTS_DIR):
for file in files:
if file.startswith(input_filename) and file.endswith("_fake.png"):
result_path = os.path.join(root, file)
output_image = Image.open(result_path)
if preprocess_type not in ["crop", "none"]:
input_image = Image.open(input_image)
output_image = output_image.resize(input_image.size)
os.remove(result_path)
elif file.startswith(input_filename) and file.endswith("_real.png"):
real_path = os.path.join(root, file)
os.remove(real_path)
clear_session_files(session_id)
if output_image:
return output_image
return "No results found."
def use_sample_image(sample_image_name):
    """Return the path of a sample image in SAMPLE_DIR, or an error message if it is missing."""
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
# Image files in SAMPLE_DIR offered as clickable examples. Sorted so the
# example order is stable across restarts (os.listdir order is arbitrary), and
# matched case-insensitively so files such as "photo.JPG" are included too.
sample_images = sorted(
    name for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)
# Choices for the preprocessing dropdown; these are passed through to test.py
# as its --preprocess value.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none"
]
def _run_reflection_removal(input_image, preprocess_type):
    """Gradio callback: treat an empty dropdown selection as "resize_and_crop"."""
    return reflection_removal(input_image, preprocess_type or "resize_and_crop")


# Web UI: an image upload plus a preprocessing dropdown in, one image out;
# every sample image is listed as an example with the default preprocessing.
iface = gr.Interface(
    fn=_run_reflection_removal,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
        gr.Dropdown(
            choices=preprocess_options,
            label="Preprocessing Type",
            value="resize_and_crop",
        ),
    ],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=[[os.path.join(SAMPLE_DIR, img), "resize_and_crop"] for img in sample_images],
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    iface.launch()
|