ishworrsubedii commited on
Commit
4343078
·
verified ·
1 Parent(s): f182dfb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ import numpy as np
6
+ from briarmbg import BriaRMBG
7
+ import gradio as gr
8
+ import os
9
+
10
+
11
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
12
+ if len(im.shape) < 3:
13
+ im = im[:, :, np.newaxis]
14
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
15
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
16
+ image = torch.divide(im_tensor, 255.0)
17
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
18
+ return image
19
+
20
+
21
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
22
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
23
+ ma = torch.max(result)
24
+ mi = torch.min(result)
25
+ result = (result - mi) / (ma - mi)
26
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
27
+ im_array = np.squeeze(im_array)
28
+ return im_array
29
+
30
+
31
+ def example_inference(image):
32
+ orig_im = image.copy()
33
+ orig_image = image.copy()
34
+
35
+ model_path = "model.pth"
36
+
37
+ net = BriaRMBG()
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ net.load_state_dict(torch.load(model_path, map_location=device))
40
+ net.to(device)
41
+ net.eval()
42
+
43
+ # prepare input
44
+ model_input_size = [1024, 1024]
45
+ orig_im_size = orig_im.size
46
+ orig_im_size = (orig_im_size[1], orig_im_size[0])
47
+ orig_im = np.array(orig_im)
48
+ image = preprocess_image(orig_im, model_input_size).to(device)
49
+
50
+ # inference
51
+ result = net(image)
52
+
53
+ # post process
54
+ result_image = postprocess_image(result[0][0], orig_im_size)
55
+
56
+ # save result
57
+ pil_im = Image.fromarray(result_image)
58
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
59
+ no_bg_image.paste(orig_image, mask=pil_im)
60
+
61
+ return no_bg_image
62
+
63
+
64
+ original_image, binary_image = None, None
65
+ paths = [os.path.join("bg_images", file) for file in os.listdir("bg_images")]
66
+ images = [Image.open(path) for path in paths]
67
+
68
+ with gr.Blocks(
69
+ theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.indigo)) as demo:
70
+ with gr.Row():
71
+ input_img = gr.Image(label="Input", interactive=True, type='pil')
72
+ hidden_img = gr.Image(label="Chosen Background", visible=False)
73
+ output_img = gr.Image(label="Output", interactive=False, type='pil')
74
+
75
+
76
+ def clearFunc():
77
+ global original_image
78
+ global binary_image
79
+
80
+
81
+ def update_visibility():
82
+ return gr.Image(visible=True)
83
+
84
+ torch.cuda.empty_cache()
85
+ gc.collect()
86
+ return gr.Image(visible=False, value=None)
87
+
88
+
89
+ with gr.Row():
90
+ examples = gr.Examples(examples=images, inputs=[hidden_img])
91
+
92
+ with gr.Row():
93
+ submit = gr.Button("Submit")
94
+ clear = gr.ClearButton(components=[input_img, output_img, hidden_img], value="Reset", variant="stop")
95
+
96
+
97
+ def generate_img(image, background):
98
+ orig_img = example_inference(image)
99
+ height, width = orig_img.size
100
+
101
+ background = Image.fromarray(background).resize((width, height))
102
+ orig_img = Image.fromarray(np.array(orig_img)).resize((width, height))
103
+ background.paste(orig_img, (0, 0), mask=orig_img)
104
+ return background
105
+
106
+
107
+ hidden_img.change(fn=update_visibility, inputs=[], outputs=[hidden_img])
108
+
109
+ submit.click(generate_img, inputs=[input_img, hidden_img], outputs=[output_img])
110
+ clear.click(fn=clearFunc, inputs=[], outputs=[hidden_img])
111
+
112
+ demo.launch(share=True, debug=True)