File size: 4,267 Bytes
e546fea
 
 
 
 
2b4b81f
1a24fbc
e546fea
958511f
e546fea
958511f
 
 
3618356
958511f
 
 
 
 
 
 
e546fea
 
958511f
 
3e75999
2d64873
2b4b81f
2d64873
1a24fbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d64873
 
f0a6ca3
e546fea
 
 
 
3e75999
 
1a24fbc
2b4b81f
1a24fbc
2d64873
 
e546fea
1a24fbc
e546fea
1a24fbc
 
 
 
 
20a2fe0
1a24fbc
 
958511f
1a24fbc
 
958511f
1a24fbc
 
 
958511f
1a24fbc
 
e546fea
1a24fbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c1d97a
1a24fbc
 
 
 
 
6c1d97a
1a24fbc
 
 
 
 
 
 
 
 
 
 
 
e546fea
 
1a24fbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import tempfile

# Select the "high" float32 matmul precision setting for torch.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model and keep it on the CPU.
# NOTE(review): trust_remote_code=True lets the Hub repository supply the
# model implementation -- only use repositories you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Preprocessing applied to every input image before inference:
# resize to 1024x1024, convert to a tensor, then normalize per channel.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def fn(image):
    """Load *image*, convert it to RGB, and run background removal on it.

    Args:
        image: Any image source accepted by ``load_img``.

    Returns:
        The value returned by ``process`` for the RGB image, i.e. an
        ``(image_with_alpha, mask)`` tuple.
    """
    im = load_img(image, output_type="pil").convert("RGB")
    return process(im)

def parse_color(color):
    """Convert a CSS-style color string into an opaque RGBA tuple.

    Supported notations (case-insensitive, surrounding whitespace ignored):

    * hex: ``#rgb``, ``#rgba``, ``#rrggbb``, ``#rrggbbaa``
    * functional: ``rgb(r, g, b)`` and ``rgba(r, g, b, a)`` with integer or
      float channel values (floats are truncated toward zero)

    Any alpha component in the input is ignored; the result is always fully
    opaque. Channel values are clamped to the 0-255 range.

    Args:
        color: The color string to parse.

    Returns:
        tuple[int, int, int, int]: ``(r, g, b, 255)``. Unrecognized,
        malformed, or non-string input yields opaque white.
    """
    fallback = (255, 255, 255, 255)
    if not isinstance(color, str):
        return fallback

    text = color.strip().lower()
    try:
        if text.startswith("#"):
            digits = text.lstrip("#")
            if len(digits) in (3, 4):
                # Shorthand hex: every digit is doubled ("f80" -> "ff8800").
                digits = "".join(ch * 2 for ch in digits)
            if len(digits) < 6:
                return fallback
            channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
        elif text.startswith("rgb"):
            # Handles both "rgb(...)" and "rgba(...)"; only the first three
            # comma-separated values are used.
            _, _, inner = text.partition("(")
            parts = inner.replace(")", "").split(",")
            if len(parts) < 3:
                return fallback
            channels = [int(float(part)) for part in parts[:3]]
        else:
            return fallback
    except (ValueError, OverflowError):
        # Non-numeric, NaN, or infinite channel values.
        return fallback

    r, g, b = (min(255, max(0, value)) for value in channels)
    return (r, g, b, 255)

def process(image):
    """Predict a mask for *image* and return a cut-out with that mask as alpha.

    The image is run through ``transform_image`` and the ``birefnet`` model on
    the CPU; the last element of the model output is passed through a sigmoid
    and converted to a PIL mask resized to the input image's size.

    Args:
        image: PIL image to segment (callers in this file pass RGB images).
            It is not modified.

    Returns:
        tuple: ``(cutout, mask)`` where ``cutout`` is a copy of *image* with
        ``mask`` installed as its alpha channel, and ``mask`` is the
        predicted mask as a PIL image of the same size as *image*.
    """
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cpu")
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)
    # putalpha() modifies the image in place, so apply it to a copy to avoid
    # changing the caller's image as a side effect.
    cutout = image.copy()
    cutout.putalpha(mask)
    return cutout, mask

def _save_temp_image(img, suffix, fmt):
    """Write *img* to a new temporary file in format *fmt* and return its path."""
    # NamedTemporaryFile creates the file itself, avoiding the race inherent
    # in the deprecated tempfile.mktemp(), which only returns an unused name.
    # delete=False keeps the file on disk after the handle is closed, since
    # the path is handed back to the caller.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        img.save(tmp, fmt)
    return tmp.name


def process_file(f, bg_color):
    """Remove the background of an image and produce preview/download outputs.

    Args:
        f: Image source accepted by ``load_img`` (the UI passes a file path).
        bg_color: Color string understood by ``parse_color``; used as the
            solid background for the composited variant.

    Returns:
        tuple: ``(with_bg_rgb, bg_png_path, bg_jpeg_path, transparent_img,
        trans_png_path)`` -- the composited RGB preview, paths to its PNG and
        JPEG files, the transparent RGBA cut-out, and the path to its PNG file.

    Raises:
        gr.Error: If no image was provided.
    """
    if f is None:
        raise gr.Error("Please upload an image first.")

    im = load_img(f, output_type="pil").convert("RGB")
    transparent_img, _ = process(im)

    # Variant composited over the chosen background color.
    background = Image.new("RGBA", im.size, parse_color(bg_color))
    with_bg = Image.alpha_composite(background, transparent_img)
    # JPEG has no alpha channel, so an RGB version is needed for that file.
    with_bg_rgb = with_bg.convert("RGB")

    bg_png_path = _save_temp_image(with_bg, ".png", "PNG")
    bg_jpeg_path = _save_temp_image(with_bg_rgb, ".jpeg", "JPEG")

    # Transparent variant (no background).
    trans_png_path = _save_temp_image(transparent_img, ".png", "PNG")

    return (with_bg_rgb, bg_png_path, bg_jpeg_path,
            transparent_img, trans_png_path)

# Custom stylesheet passed to gr.Blocks below: adds space under the page
# heading and overrides sizing / hides SVGs for elements carrying the
# "small-file" class (assigned to the download File components).
css = """
.gradio-container h1 {
    margin-bottom: 24px;
}
.small-file, .small-file * {
    min-height: 0 !important;
    height: auto !important;
}
.small-file svg {
    display: none !important;
}
"""

# Build the UI: an input column (image, background color, submit button) next
# to an output column holding two result panels.
with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
    gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=2):
            gr.Markdown("### Input")
            # type="filepath": the handler is given the image as a path.
            input_image = gr.Image(label="Upload an image", type="filepath")
            color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
            submit_btn = gr.Button("Submit", variant="primary")

        # Right column: the two output variants side by side.
        with gr.Column(scale=2):
            with gr.Row():
                # Result composited over the chosen background color.
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Background Color")
                    bg_preview = gr.Image(label="Preview")
                    bg_png = gr.File(label="Download PNG", elem_classes="small-file")
                    bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")

                # Result with a transparent background.
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Transparent Background")
                    trans_preview = gr.Image(label="Preview")
                    trans_png = gr.File(label="Download PNG", elem_classes="small-file")

    # Example row: a sample image file name plus a white background color.
    gr.Examples(
        examples=[["butterfly.jpg", "#ffffff"]],
        inputs=[input_image, color_picker]
    )

    # The outputs list is ordered to match the tuple returned by process_file.
    submit_btn.click(
        fn=process_file,
        inputs=[input_image, color_picker],
        outputs=[bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    # share=True is Gradio's option for requesting a public share link.
    background_remover_app.launch(share=True)