chuuhtetnaing commited on
Commit
1a24fbc
·
1 Parent(s): eb65e05

improved UI features for the gradio

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +82 -27
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ .idea
3
+ .gradio
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
- from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
- import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
8
  from PIL import Image
 
9
 
10
  torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
@@ -28,49 +27,105 @@ def fn(image):
28
  image = process(im)
29
  return image
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def process(image):
32
  image_size = image.size
33
  input_images = transform_image(image).unsqueeze(0).to("cpu")
34
- # Prediction
35
  with torch.no_grad():
36
  preds = birefnet(input_images)[-1].sigmoid().cpu()
37
  pred = preds[0].squeeze()
38
  pred_pil = transforms.ToPILImage()(pred)
39
  mask = pred_pil.resize(image_size)
40
-
41
- white_background = Image.new("RGBA", image_size, (255, 255, 255, 255))
42
  image.putalpha(mask)
43
- combined = Image.alpha_composite(white_background, image)
44
 
45
- return combined
46
-
47
- def process_file(f):
48
- name_path = f.rsplit(".",1)[0]+".jpeg"
49
  im = load_img(f, output_type="pil")
50
  im = im.convert("RGB")
51
- transparent = process(im)
52
- rgb_image = transparent.convert("RGB") # Ensure the final image is in RGB mode for JPEG
53
- rgb_image.save(name_path)
54
- return name_path
55
 
56
- slider1 = gr.Image()
57
- slider2 = ImageSlider(label="birefnet", type="pil")
58
- image = gr.Image(label="Upload an image")
59
- image2 = gr.Image(label="Upload an image",type="filepath")
60
- text = gr.Textbox(label="Paste an image URL")
61
- png_file = gr.File(label="output jpeg file")
62
 
 
 
 
 
 
63
 
64
- chameleon = load_img("butterfly.jpg", output_type="pil")
 
65
 
66
- url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
 
67
 
68
- tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["butterfly.jpg"], api_name="png")
 
 
69
 
 
 
70
 
71
- demo = gr.TabbedInterface(
72
- [tab3], ["jpeg"], title="Na Na"
73
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == "__main__":
76
- demo.launch(share=True)
 
1
  import gradio as gr
 
2
  from loadimg import load_img
 
3
  from transformers import AutoModelForImageSegmentation
4
  import torch
5
  from torchvision import transforms
6
  from PIL import Image
7
+ import tempfile
8
 
9
  torch.set_float32_matmul_precision(["high", "highest"][0])
10
 
 
27
  image = process(im)
28
  return image
29
 
30
+ def parse_color(color):
31
+ if color.startswith('#'):
32
+ hex_color = color.lstrip('#')
33
+ r = int(hex_color[0:2], 16)
34
+ g = int(hex_color[2:4], 16)
35
+ b = int(hex_color[4:6], 16)
36
+ elif color.startswith('rgba'):
37
+ rgba_values = color.replace('rgba(', '').replace(')', '')
38
+ parts = [x.strip() for x in rgba_values.split(',')]
39
+ r, g, b = int(float(parts[0])), int(float(parts[1])), int(float(parts[2]))
40
+ elif color.startswith('rgb'):
41
+ rgb_values = color.replace('rgb(', '').replace(')', '')
42
+ r, g, b = [int(float(x.strip())) for x in rgb_values.split(',')]
43
+ else:
44
+ r, g, b = 255, 255, 255
45
+ return (r, g, b, 255)
46
+
47
  def process(image):
48
  image_size = image.size
49
  input_images = transform_image(image).unsqueeze(0).to("cpu")
 
50
  with torch.no_grad():
51
  preds = birefnet(input_images)[-1].sigmoid().cpu()
52
  pred = preds[0].squeeze()
53
  pred_pil = transforms.ToPILImage()(pred)
54
  mask = pred_pil.resize(image_size)
 
 
55
  image.putalpha(mask)
56
+ return image, mask
57
 
58
+ def process_file(f, bg_color):
 
 
 
59
  im = load_img(f, output_type="pil")
60
  im = im.convert("RGB")
 
 
 
 
61
 
62
+ transparent_img, mask = process(im)
 
 
 
 
 
63
 
64
+ # With background color
65
+ rgba_color = parse_color(bg_color)
66
+ background = Image.new("RGBA", im.size, rgba_color)
67
+ with_bg = Image.alpha_composite(background, transparent_img)
68
+ with_bg_rgb = with_bg.convert("RGB")
69
 
70
+ bg_png_path = tempfile.mktemp(suffix=".png")
71
+ with_bg.save(bg_png_path, "PNG")
72
 
73
+ bg_jpeg_path = tempfile.mktemp(suffix=".jpeg")
74
+ with_bg_rgb.save(bg_jpeg_path, "JPEG")
75
 
76
+ # Transparent (no background)
77
+ trans_png_path = tempfile.mktemp(suffix=".png")
78
+ transparent_img.save(trans_png_path, "PNG")
79
 
80
+ return (with_bg_rgb, bg_png_path, bg_jpeg_path,
81
+ transparent_img, trans_png_path)
82
 
83
+ css = """
84
+ .gradio-container h1 {
85
+ margin-bottom: 24px;
86
+ }
87
+ .small-file, .small-file * {
88
+ min-height: 0 !important;
89
+ height: auto !important;
90
+ }
91
+ .small-file svg {
92
+ display: none !important;
93
+ }
94
+ """
95
+
96
+ with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
97
+ gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")
98
+
99
+ with gr.Row():
100
+ with gr.Column(scale=2):
101
+ gr.Markdown("### Input")
102
+ input_image = gr.Image(label="Upload an image", type="filepath")
103
+ color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
104
+ submit_btn = gr.Button("Submit", variant="primary")
105
+
106
+ with gr.Column(scale=2):
107
+ with gr.Row():
108
+ with gr.Column(scale=1):
109
+ gr.Markdown("### Output With Background Color")
110
+ bg_preview = gr.Image(label="Preview", format="png")
111
+ bg_png = gr.File(label="Download PNG", elem_classes="small-file")
112
+ bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")
113
+
114
+ with gr.Column(scale=1):
115
+ gr.Markdown("### Output With Transparent Background")
116
+ trans_preview = gr.Image(label="Preview", format="png")
117
+ trans_png = gr.File(label="Download PNG", elem_classes="small-file")
118
+
119
+ gr.Examples(
120
+ examples=[["butterfly.jpg", "#ffffff"]],
121
+ inputs=[input_image, color_picker]
122
+ )
123
+
124
+ submit_btn.click(
125
+ fn=process_file,
126
+ inputs=[input_image, color_picker],
127
+ outputs=[bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]
128
+ )
129
 
130
  if __name__ == "__main__":
131
+ background_remover_app.launch(share=True)