File size: 5,169 Bytes
2bb76b3
b79726b
2bb76b3
d3aec7f
2bb76b3
b79726b
2bb76b3
b79726b
2bb76b3
 
9489bae
2bb76b3
 
 
9489bae
2bb76b3
b79726b
 
 
2bb76b3
0b76722
2bb76b3
 
a0c04c5
 
 
02a2aa5
b79726b
 
 
 
 
 
 
a0c04c5
b79726b
 
 
 
 
 
ec61f43
 
 
 
 
7ba731c
b79726b
 
 
d3aec7f
02a2aa5
d3aec7f
 
 
b79726b
 
26e1c44
d3aec7f
b79726b
d3aec7f
2bb76b3
9489bae
d3aec7f
2bb76b3
 
 
 
d3aec7f
95f6ecc
220be7c
2bb76b3
7ba731c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3aec7f
b79726b
c9b5e9c
 
 
9489bae
02a2aa5
60c5500
 
 
 
 
f6fd486
 
 
 
4c9b339
d3aec7f
f6fd486
b79726b
f6fd486
2d422fa
4c9b339
b79726b
 
 
 
 
2bb76b3
9489bae
 
 
 
2bb76b3
b8a1b3c
 
 
 
b79726b
7bd8b18
ec61f43
b79726b
 
 
 
ec61f43
b79726b
 
 
 
95b14fb
b79726b
2bb76b3
 
b79726b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import shutil
import subprocess
import uuid
from PIL import Image
import gradio as gr

# Filesystem layout used by the app (all paths relative to the working
# directory the script is started from).
UPLOAD_DIR = "./sessions"          # per-request working directories
RESULTS_DIR = "./results"          # scanned for generated *_fake.png files
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"     # example images offered in the UI

# UPLOAD_DIR is intentionally not created here: each request creates its own
# sub-directory below it on demand.
for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)

from huggingface_hub import hf_hub_download
from shutil import copyfile

# Hugging Face Hub repository and file holding the generator weights.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Location directly inside CHECKPOINTS_DIR where a plain copy of the weights
# file is kept (the hub download itself lands in the hub's cache layout
# below CHECKPOINTS_DIR).
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)

# Fetch the weights at import time, using CHECKPOINTS_DIR as the cache.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# Copy only once; an existing file at the expected path is left untouched.
if not os.path.exists(expected_model_path):
    shutil.copyfile(model_path, expected_model_path)

def generate_session_id():
    """Return a new random UUID4 in its canonical string form."""
    session_uuid = uuid.uuid4()
    return str(session_uuid)

def randomize_file_name(original_name):
    """Build a random file name that keeps the extension of *original_name*.

    Args:
        original_name: File name (or path) whose extension should be kept.

    Returns:
        A 32-character hexadecimal UUID4 followed by the original extension
        (including the leading dot), or by nothing if there was none.
    """
    _, suffix = os.path.splitext(original_name)
    return uuid.uuid4().hex + suffix

def clear_session_files(session_id):
    """Delete the working directory of the given session, if it exists.

    Args:
        session_id: Name of the session sub-directory below UPLOAD_DIR.
    """
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        # Nothing was created for this session (or it is already gone).
        return
    shutil.rmtree(target)

def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Run the reflection-removal model on one image and return the result.

    The image is copied under a random name into a per-request session
    directory, ``test.py`` is executed on that directory as a subprocess, and
    the generated ``*_fake.png`` file is then looked up below RESULTS_DIR.

    Args:
        input_image: Filesystem path of the image to process.
        preprocess_type: Value passed to ``test.py`` via ``--preprocess``;
            one of "resize_and_crop", "crop", "scale_width",
            "scale_width_and_crop" or "none".

    Returns:
        A PIL image with the model output on success. For every value of
        ``preprocess_type`` except "crop" and "none" the output is resized to
        the dimensions of the input image. On failure a human-readable
        message string is returned instead.

    NOTE(review): the UI wires this function to a ``gr.Image`` output, which
    presumably cannot display these message strings; raising ``gr.Error``
    may be the better contract -- confirm before changing the return type.
    """
    valid_preprocess_types = (
        "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
    )
    if preprocess_type not in valid_preprocess_types:
        return "Invalid preprocessing type selected. Please choose a valid option."

    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)

    # Validate the input before touching the filesystem so that a rejected
    # request does not leave an empty session directory behind.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)

    output_image = None
    try:
        # A random name keeps the result files of concurrent requests apart.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        shutil.copy(input_image, os.path.join(upload_dir, randomized_name))
        input_filename = os.path.splitext(randomized_name)[0]

        base_cmd = [
            "python", "test.py",
            "--dataroot", upload_dir,
            "--name", "SingleImageReflectionRemoval",
            "--model", "test", "--netG", "unet_256",
            "--direction", "AtoB", "--dataset_mode", "single",
            "--norm", "batch", "--epoch", "310",
            "--num_test", "1",
            "--gpu_ids", "-1",
        ]
        # First attempt honours the requested preprocessing; the two retries
        # omit --preprocess so test.py falls back to its own default.
        attempts = (
            base_cmd + ["--preprocess", preprocess_type],
            base_cmd,
            base_cmd,
        )
        for cmd in attempts:
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError:
                continue
            break
        else:
            # Every attempt exited with a non-zero status.
            return "No results found. Please try again with a different image."

        for root, _, files in os.walk(RESULTS_DIR):
            for file in files:
                if not file.startswith(input_filename):
                    continue
                found_path = os.path.join(root, file)
                if file.endswith("_fake.png"):
                    # Materialise the pixels while the file is still open so
                    # the handle can be closed before the file is deleted.
                    with Image.open(found_path) as generated:
                        if preprocess_type not in ["crop", "none"]:
                            with Image.open(input_image) as original:
                                original_size = original.size
                            output_image = generated.resize(original_size)
                        else:
                            output_image = generated.copy()
                    os.remove(found_path)
                elif file.endswith("_real.png"):
                    # Echo of the input written next to the result; discard.
                    os.remove(found_path)
    finally:
        # Always drop the session directory, including on failed runs and
        # unexpected exceptions.
        clear_session_files(session_id)

    if output_image is not None:
        return output_image
    return "No results found."

def use_sample_image(sample_image_name):
    """Resolve a sample image name to its path inside SAMPLE_DIR.

    Args:
        sample_image_name: File name of the sample image.

    Returns:
        The joined path if it exists, otherwise the message
        "Sample image not found.".
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."

# File names of the example images found in SAMPLE_DIR. The extension check
# is case-insensitive so files such as "photo.JPG" are picked up too, and the
# names are sorted because os.listdir returns entries in arbitrary order.
sample_images = sorted(
    entry for entry in os.listdir(SAMPLE_DIR)
    if entry.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Choices offered in the "Preprocessing Type" dropdown; the selected value is
# forwarded to test.py through its --preprocess flag.
preprocess_options = [
    "resize_and_crop",
    "crop",
    "scale_width",
    "scale_width_and_crop",
    "none",
]

# Input widgets: the image is handed to the callback as a file path, and the
# dropdown preselects "resize_and_crop".
_image_input = gr.Image(type="filepath", label="Upload Image (JPG/PNG)")
_preprocess_input = gr.Dropdown(
    choices=preprocess_options,
    label="Preprocessing Type",
    value="resize_and_crop",
)

# One example row per sample image, each paired with the default preprocessing.
_example_rows = [
    [os.path.join(SAMPLE_DIR, img), "resize_and_crop"]
    for img in sample_images
]

# Gradio UI definition. The lambda substitutes "resize_and_crop" when the
# dropdown yields an empty/None value.
iface = gr.Interface(
    fn=lambda input_image, preprocess_type: reflection_removal(
        input_image, preprocess_type or "resize_and_crop"
    ),
    inputs=[_image_input, _preprocess_input],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=_example_rows,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    iface.launch()