"""
Created By: ishwor subedi
Date: 2024-05-19
"""

from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np
from architecture import BackgroundEnhancer
import gradio as gr
import os
import gc


def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Convert an HWC uint8 image to a normalized NCHW float tensor at the model input size."""
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # keep float precision here; casting the interpolated tensor to uint8 would discard it
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image


def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize the predicted matte back to the original size and min-max scale it to uint8."""
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array


def example_inference(image):
    """Run the matting model on a PIL image and return an RGBA cutout of the subject."""
    orig_image = image.copy()  # kept intact for the final alpha composite

    model_path = "model.pth"

    # NOTE: the model is reloaded from disk on every call; loading it once at
    # module scope would avoid the repeated cost.
    net = BackgroundEnhancer()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()

    # prepare input: the model expects a fixed 1024x1024 tensor
    model_input_size = [1024, 1024]
    # PIL reports (width, height); F.interpolate expects (height, width)
    orig_im_size = (image.size[1], image.size[0])
    image = preprocess_image(np.array(image), model_input_size).to(device)

    # inference: no gradients are needed at eval time
    with torch.no_grad():
        result = net(image)

    # post-process the first output head back to the original resolution
    result_image = postprocess_image(result[0][0], orig_im_size)

    # composite: the predicted matte becomes the alpha mask over the original image
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    no_bg_image.paste(orig_image, mask=pil_im)

    return no_bg_image


original_image, binary_image = None, None
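# Preload the example background galleries from the bg_images/ subfolders
# (the folders are assumed to sit next to this script).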
def load_backgrounds(folder: str) -> list:
    """Open every image in a bg_images/ subfolder as a PIL image."""
    root = os.path.join("bg_images", folder)
    return [Image.open(os.path.join(root, file)) for file in os.listdir(root)]


colors = load_backgrounds("color")
houses = load_backgrounds("house")
natures = load_backgrounds("nature")
studios = load_backgrounds("studio")
walls = load_backgrounds("wall")
woods = load_backgrounds("wood")

with gr.Blocks(
        theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.indigo)) as demo:
    with gr.Row():
        input_img = gr.Image(label="Input", interactive=True, type='pil')
        hidden_img = gr.Image(label="Chosen Background", visible=False)
        output_img = gr.Image(label="Output", interactive=False, type='pil')
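        # hidden_img doubles as state: clicking a gallery example loads the
        # chosen background here, and generate_img reads it on Submit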


    def clearFunc():
        """Reset on the Reset button: drop state, free GPU memory, hide the background preview."""
        global original_image
        global binary_image
        original_image, binary_image = None, None

        torch.cuda.empty_cache()
        gc.collect()
        return gr.Image(visible=False, value=None)


    def update_visibility():
        """Reveal the hidden background preview once a background has been chosen."""
        return gr.Image(visible=True)


    with gr.Row():
        examples = gr.Examples(examples=studios, inputs=[hidden_img], label="Studio Backgrounds")

    with gr.Row():
        examples6 = gr.Examples(examples=colors, inputs=[hidden_img], label="Color Backgrounds")

    with gr.Row():
        examples2 = gr.Examples(examples=walls, inputs=[hidden_img], label="Wall Backgrounds")
        examples3 = gr.Examples(examples=natures, inputs=[hidden_img], label="Nature Backgrounds")

    with gr.Row():
        examples4 = gr.Examples(examples=houses, inputs=[hidden_img], label="House Backgrounds")
        examples5 = gr.Examples(examples=woods, inputs=[hidden_img], label="Wood Backgrounds")

    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.ClearButton(components=[input_img, output_img, hidden_img], value="Reset", variant="stop")


    def generate_img(image, background):
        """Cut the subject out of `image` and paste it over the chosen background."""
        cutout = example_inference(image)  # RGBA cutout at the input image's size
        width, height = cutout.size

        # hidden_img delivers a numpy array by default, so convert and match sizes
        background = Image.fromarray(background).resize((width, height))
        # the RGBA cutout acts as its own paste mask through its alpha channel
        background.paste(cutout, (0, 0), mask=cutout)
        return background


    # reveal the background preview whenever an example gallery sets it
    hidden_img.change(fn=update_visibility, inputs=[], outputs=[hidden_img])

    submit.click(generate_img, inputs=[input_img, hidden_img], outputs=[output_img])
    clear.click(fn=clearFunc, inputs=[], outputs=[hidden_img])

demo.launch(share=True, debug=True)
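# Expected assets alongside this script (inferred from the code above):
#   model.pth        -- trained BackgroundEnhancer weights
#   architecture.py  -- defines the BackgroundEnhancer class
#   bg_images/{color,house,nature,studio,wall,wood}/ -- example backgrounds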