File size: 4,909 Bytes
143568c 4343078 bc03e48 4343078 bc03e48 4343078 143568c 4343078 143568c 4343078 143568c 4343078 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""
Created By: ishwor subedi
Date: 2024-05-19
"""
import functools
import gc
import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize

from architecture import BackgroundEnhancer
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Convert an image array into the normalized batch tensor fed to the model.

    Args:
        im: image as an ``(H, W)`` or ``(H, W, C)`` array with 0-255 values.
            Grayscale input (2-D, or 1/2 channels i.e. gray[+alpha]) is
            replicated to three channels and any channel beyond the third
            (alpha) is dropped, because the normalization below is defined for
            exactly three channels.
        model_input_size: target ``[height, width]`` passed to ``F.interpolate``.

    Returns:
        A float32 tensor of shape ``(1, 3, *model_input_size)`` with values in
        ``[-0.5, 0.5]``.
    """
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    channels = im.shape[2]
    if channels in (1, 2):
        # gray or gray+alpha -> RGB by repeating the luminance channel
        im = np.repeat(im[:, :, :1], 3, axis=2)
    elif channels > 3:
        # drop alpha (or any extra channels)
        im = im[:, :, :3]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # Resize, then quantize back to uint8 (truncating), as the original pipeline does.
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    # Equivalent to torchvision normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]):
    # per-channel (x - mean) / std, broadcast over a (1, 3, 1, 1) view.
    mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    std = torch.tensor([1.0, 1.0, 1.0]).view(1, 3, 1, 1)
    return (image - mean) / std
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a model mask to the original image size and scale it to uint8.

    Args:
        result: mask tensor accepted by ``F.interpolate`` with a leading batch
            dimension of 1 (assumed ``(1, C, h, w)`` — confirm with caller).
        im_size: target ``(height, width)``.

    Returns:
        A uint8 array, min-max stretched to 0-255, with singleton dimensions
        squeezed out (``(H, W)`` for a single-channel mask). A constant mask
        maps to all zeros instead of dividing 0 by 0.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    span = ma - mi
    if span > 0:
        result = (result - mi) / span
    else:
        # Constant mask: (x - mi) / 0 would be NaN, whose uint8 cast is undefined.
        result = torch.zeros_like(result)
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return np.squeeze(im_array)
@functools.lru_cache(maxsize=None)
def _load_background_model(model_path: str, device_name: str):
    """Build a BackgroundEnhancer, load *model_path* onto *device_name*, and cache it."""
    device = torch.device(device_name)
    net = BackgroundEnhancer()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()
    return net


def example_inference(image):
    """Cut the foreground out of *image* and return it on a transparent canvas.

    Args:
        image: a PIL image (its ``.size`` is read and it is pasted through the
            predicted mask).

    Returns:
        A new ``RGBA`` PIL image the same size as *image*; pixels where the mask
        is 0 stay fully transparent.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The weights are read from disk once per device rather than on every request.
    net = _load_background_model("model.pth", device.type)

    # prepare input
    model_input_size = [1024, 1024]
    width, height = image.size
    orig_im_size = (height, width)  # F.interpolate expects (height, width)
    input_tensor = preprocess_image(np.array(image), model_input_size).to(device)

    # inference: gradients are never used, so don't build the autograd graph
    with torch.no_grad():
        result = net(input_tensor)

    # post process
    # NOTE(review): result[0][0] is assumed to be the primary mask, shape
    # (1, 1, H, W) — confirm against BackgroundEnhancer.forward.
    mask = Image.fromarray(postprocess_image(result[0][0], orig_im_size))

    # Paste the original pixels onto a fully transparent canvas through the mask.
    no_bg_image = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    no_bg_image.paste(image, mask=mask)
    return no_bg_image
# Module-level state; declared global by clearFunc in the UI section.
original_image, binary_image = None, None


def _load_backgrounds(folder: str) -> list:
    """Return every file in *folder* opened as a PIL image, sorted by filename.

    Sorting gives the example gallery a stable order (``os.listdir`` order is
    arbitrary), and ``load()`` reads the pixel data now because ``Image.open``
    is lazy, so a corrupt file fails at startup rather than at render time.
    """
    images = []
    for name in sorted(os.listdir(folder)):
        picture = Image.open(os.path.join(folder, name))
        picture.load()
        images.append(picture)
    return images


colors = _load_backgrounds("bg_images/color")
houses = _load_backgrounds("bg_images/house")
natures = _load_backgrounds("bg_images/nature")
studios = _load_backgrounds("bg_images/studio")
walls = _load_backgrounds("bg_images/wall")
woods = _load_backgrounds("bg_images/wood")
with gr.Blocks(
    theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.indigo)
) as demo:
    with gr.Row():
        input_img = gr.Image(label="Input", interactive=True, type='pil')
        # Receives the background picked from the example galleries below;
        # starts hidden and is revealed by update_visibility.
        hidden_img = gr.Image(label="Chosen Background", visible=False)
        output_img = gr.Image(label="Output", interactive=False, type='pil')

    def clearFunc():
        """Reset module state, release cached memory, and hide the background slot.

        Returns:
            A ``gr.Image`` update that hides ``hidden_img`` and clears its value.
        """
        global original_image
        global binary_image
        original_image = None
        binary_image = None
        torch.cuda.empty_cache()
        gc.collect()
        return gr.Image(visible=False, value=None)

    def update_visibility():
        """Return an update that makes ``hidden_img`` visible."""
        # NOTE(review): .change presumably also fires when Reset clears the value,
        # which may re-show an empty slot depending on ordering with clearFunc — verify.
        return gr.Image(visible=True)

    with gr.Row():
        examples = gr.Examples(examples=studios, inputs=[hidden_img], label="Studio Backgrounds")
    with gr.Row():
        examples6 = gr.Examples(examples=colors, inputs=[hidden_img], label="Color Backgrounds")
    with gr.Row():
        examples2 = gr.Examples(examples=walls, inputs=[hidden_img], label="Wall Backgrounds")
        examples3 = gr.Examples(examples=natures, inputs=[hidden_img], label="Nature Backgrounds")
    with gr.Row():
        examples4 = gr.Examples(examples=houses, inputs=[hidden_img], label="House Backgrounds")
        examples5 = gr.Examples(examples=woods, inputs=[hidden_img], label="Wood Backgrounds")
    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.ClearButton(components=[input_img, output_img, hidden_img], value="Reset", variant="stop")

    def generate_img(image, background):
        """Cut the subject out of *image* and composite it over *background*.

        Args:
            image: PIL image from the "Input" component, or None if nothing was uploaded.
            background: array from the "Chosen Background" component (passed to
                ``Image.fromarray``), or None if no background was picked.

        Returns:
            A PIL image the size of *image*: *background* stretched to that size
            (aspect ratio not preserved) with the cut-out subject pasted on top.

        Raises:
            gr.Error: when either input is missing, so the user sees a message
                instead of a traceback.
        """
        if image is None:
            raise gr.Error("Please upload an input image.")
        if background is None:
            raise gr.Error("Please choose a background.")
        foreground = example_inference(image)
        canvas = Image.fromarray(background).resize(foreground.size)
        # The RGBA foreground is its own mask: its alpha channel is the predicted matte.
        canvas.paste(foreground, (0, 0), mask=foreground)
        return canvas

    hidden_img.change(fn=update_visibility, inputs=[], outputs=[hidden_img])
    submit.click(generate_img, inputs=[input_img, hidden_img], outputs=[output_img])
    clear.click(fn=clearFunc, inputs=[], outputs=[hidden_img])

if __name__ == "__main__":
    demo.launch(share=True, debug=True)
|