File size: 4,267 Bytes
e546fea 2b4b81f 1a24fbc e546fea 958511f e546fea 958511f 3618356 958511f e546fea 958511f 3e75999 2d64873 2b4b81f 2d64873 1a24fbc 2d64873 f0a6ca3 e546fea 3e75999 1a24fbc 2b4b81f 1a24fbc 2d64873 e546fea 1a24fbc e546fea 1a24fbc 20a2fe0 1a24fbc 958511f 1a24fbc 958511f 1a24fbc 958511f 1a24fbc e546fea 1a24fbc 6c1d97a 1a24fbc 6c1d97a 1a24fbc e546fea 1a24fbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import tempfile
# Allow float32 matmuls to use the faster "high" precision mode
# (switch to "highest" for full float32 precision).
_MATMUL_PRECISION = "high"
torch.set_float32_matmul_precision(_MATMUL_PRECISION)

# BiRefNet segmentation model from the Hugging Face Hub. trust_remote_code
# lets transformers execute the model code shipped with the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Preprocessing applied to every input image: resize to the fixed model
# input size, convert to a float tensor, then normalize each channel with
# the standard ImageNet mean/std values.
_MODEL_INPUT_SIZE = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])
def fn(image):
    """Remove the background of *image*.

    Args:
        image: Any input ``load_img`` accepts (presumably a file path, URL,
            or PIL image -- confirm against the loadimg docs). It is loaded
            as a PIL image and converted to RGB before segmentation.

    Returns:
        The ``(rgba_image, mask)`` tuple produced by ``process``. Note that
        this is a tuple, not a single image.
    """
    # The former ``origin = im.copy()`` was never used; dropping it avoids
    # an extra full-resolution copy per call.
    im = load_img(image, output_type="pil").convert("RGB")
    return process(im)
def parse_color(color):
    """Convert a CSS-style color string into an opaque RGBA tuple.

    Supported forms (case- and surrounding-whitespace-insensitive):
      * ``#rgb`` / ``#rgba`` shorthand and ``#rrggbb`` / ``#rrggbbaa`` hex;
      * ``rgb(r, g, b)`` and ``rgba(r, g, b, a)`` with integer or float
        components (fractions are truncated).

    Any alpha in the input is ignored: the returned alpha is always 255.
    ``None``, an empty string, or an unrecognized format falls back to white.
    Channel values are clamped to ``0..255``.

    Args:
        color: The color string, e.g. from a ``gr.ColorPicker``; may be None.

    Returns:
        A ``(r, g, b, 255)`` tuple of ints.

    Raises:
        ValueError: If a hex or rgb()/rgba() string is malformed.
    """
    white = (255, 255, 255, 255)
    # A cleared color picker can yield None/""; treat it like the default.
    if not color:
        return white
    color = color.strip().lower()

    if color.startswith('#'):
        hex_color = color.lstrip('#')
        # Expand CSS shorthand: "#abc" -> "#aabbcc", "#abcd" -> "#aabbccdd".
        if len(hex_color) in (3, 4):
            hex_color = ''.join(ch * 2 for ch in hex_color)
        r = int(hex_color[0:2], 16)
        g = int(hex_color[2:4], 16)
        b = int(hex_color[4:6], 16)
    elif color.startswith('rgb'):
        # Handles both "rgb(...)" and "rgba(...)"; only the first three
        # components are used, so a trailing alpha is simply ignored.
        inner = color.partition('(')[2].rstrip(')')
        parts = [x.strip() for x in inner.split(',')]
        r, g, b = (int(float(x)) for x in parts[:3])
    else:
        return white

    # Guard against out-of-range components such as "rgb(300, -5, 0)".
    r, g, b = (max(0, min(255, v)) for v in (r, g, b))
    return (r, g, b, 255)
def process(image):
    """Segment the foreground of *image* with BiRefNet.

    Args:
        image: A PIL image of any mode. It is converted to RGB internally,
            and the caller's object is NOT modified.

    Returns:
        A tuple ``(rgba, mask)``. ``rgba`` is an RGB copy of *image* with the
        predicted mask attached as its alpha channel. ``mask`` is the
        single-channel PIL mask, resized back to the original image size.
    """
    # convert() always returns a new image (a copy even when the mode already
    # matches), so putalpha() below no longer mutates the caller's image. It
    # also guarantees the 3 channels that transform_image's Normalize needs.
    rgb = image.convert("RGB")
    # Feed the tensor to whatever device the model lives on, rather than a
    # hard-coded "cpu", so relocating birefnet needs no change here.
    device = next(birefnet.parameters()).device
    input_images = transform_image(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        # Use the last element of the model's outputs as the mask logits.
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # Drop batch/channel dims; assumes a single-channel map -- TODO confirm.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(rgb.size)
    rgb.putalpha(mask)
    return rgb, mask
def _save_temp(img, suffix, fmt):
    """Save *img* to a new, uniquely named temporary file and return its path."""
    # NamedTemporaryFile creates the file atomically, unlike the deprecated
    # tempfile.mktemp, which only picks a name and leaves a race window.
    # delete=False keeps the file on disk after the handle closes, so the
    # returned path can still be served as a download.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        img.save(tmp, fmt)
    return tmp.name


def process_file(f, bg_color):
    """Remove the background from an uploaded image and export the results.

    Args:
        f: The uploaded image as given by the ``gr.Image(type="filepath")``
            input; None when nothing was uploaded.
        bg_color: Color string from the color picker, used as the solid
            background. None or empty falls back to white.

    Returns:
        A 5-tuple ``(with_bg_rgb, bg_png_path, bg_jpeg_path,
        transparent_img, trans_png_path)``: the composited RGB preview, PNG and
        JPEG paths of that composite, the RGBA cut-out, and its PNG path.

    Raises:
        gr.Error: If no image was provided.
    """
    if f is None:
        # gr.Error is shown to the user in the UI instead of a generic failure.
        raise gr.Error("Please upload an image first.")

    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent_img, mask = process(im)

    # Composite the cut-out over a solid background of the chosen color.
    rgba_color = parse_color(bg_color or "#ffffff")
    background = Image.new("RGBA", im.size, rgba_color)
    with_bg = Image.alpha_composite(background, transparent_img)
    # JPEG has no alpha channel, so the JPEG export uses an RGB copy.
    with_bg_rgb = with_bg.convert("RGB")
    bg_png_path = _save_temp(with_bg, ".png", "PNG")
    bg_jpeg_path = _save_temp(with_bg_rgb, ".jpeg", "JPEG")

    # Transparent (no background) version.
    trans_png_path = _save_temp(transparent_img, ".png", "PNG")

    return (with_bg_rgb, bg_png_path, bg_jpeg_path,
            transparent_img, trans_png_path)
# Custom CSS passed to the gr.Blocks app:
# - adds spacing below the page's <h1> title;
# - collapses the height of components tagged with the "small-file" class
#   (the download gr.File widgets) and hides their SVG icons, so the
#   download links take up little vertical space.
css = """
.gradio-container h1 {
margin-bottom: 24px;
}
.small-file, .small-file * {
min-height: 0 !important;
height: auto !important;
}
.small-file svg {
display: none !important;
}
"""
# Two-column layout: inputs on the left, two result panes on the right.
with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
    gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")

    with gr.Row():
        # Left: image upload, background color choice, and the submit button.
        with gr.Column(scale=2):
            gr.Markdown("### Input")
            input_image = gr.Image(label="Upload an image", type="filepath")
            color_picker = gr.ColorPicker(
                label="Background Color",
                value="#ffffff",
            )
            submit_btn = gr.Button("Submit", variant="primary")

        # Right: side-by-side panes for the colored and transparent outputs.
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Background Color")
                    bg_preview = gr.Image(label="Preview")
                    bg_png = gr.File(
                        label="Download PNG", elem_classes="small-file"
                    )
                    bg_jpeg = gr.File(
                        label="Download JPEG", elem_classes="small-file"
                    )
                with gr.Column(scale=1):
                    gr.Markdown("### Output With Transparent Background")
                    trans_preview = gr.Image(label="Preview")
                    trans_png = gr.File(
                        label="Download PNG", elem_classes="small-file"
                    )

    gr.Examples(
        examples=[["butterfly.jpg", "#ffffff"]],
        inputs=[input_image, color_picker],
    )

    # Order must match the 5-tuple returned by process_file.
    app_outputs = [bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]
    submit_btn.click(
        fn=process_file,
        inputs=[input_image, color_picker],
        outputs=app_outputs,
    )

if __name__ == "__main__":
    background_remover_app.launch(share=True)
|