File size: 2,767 Bytes
e546fea 1c37c95 e546fea 958511f e546fea 958511f b218be6 958511f e546fea 958511f 3e75999 b218be6 9352739 1c37c95 835324b 6f29338 2d64873 958511f e546fea 3e75999 2d64873 b218be6 2d64873 b218be6 2d64873 e546fea 4e70dde b218be6 5f967b0 b218be6 c6d7e41 4e70dde 7b293a2 b218be6 b2b24c7 b218be6 958511f 5f967b0 4e70dde c6d7e41 e546fea 958511f c6d7e41 2dd436d c6d7e41 958511f e546fea b218be6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import uuid
# Allow the faster "high" float32 matmul precision setting rather than
# "highest" (see torch.set_float32_matmul_precision for what each permits).
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model from the Hugging Face Hub.
# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository, so this must only ever point at a repository you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to each PIL image before inference:
#   1. resize to a fixed 1024x1024 (aspect ratio is NOT preserved),
#   2. convert to a tensor,
#   3. normalize per channel with the standard ImageNet mean/std values.
_preprocess_steps = [
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
]
transform_image = transforms.Compose(_preprocess_steps)
def fn(image):
    """Remove the background from one image and save the result as a PNG.

    Shared by the "Image Input" and "URL Input" tabs. *image* is whatever the
    tab's input component supplies (the uploaded image, or the URL text) and
    is decoded by ``load_img``.

    Returns:
        ``((processed_image, origin), name_path)``: the pair shown in the
        comparison slider (the cut-out returned by ``process`` and an
        untouched RGB copy of the input) plus the path of the saved PNG.

    Raises:
        gr.Error: if no image was uploaded or the URL box is empty.
    """
    # Pasted URLs frequently carry stray spaces/newlines; strip them so the
    # fetch does not fail on an otherwise valid address.
    if isinstance(image, str):
        image = image.strip()
    # An empty submission (presumably None from the upload widget, "" from the
    # textbox) would otherwise fail deep inside load_img with an obscure
    # message; report it clearly in the UI instead.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Please provide an image or an image URL.")

    im = load_img(image, output_type="pil").convert("RGB")
    # process() attaches the alpha mask to its argument in place (putalpha),
    # so keep an untouched copy for the "original" side of the slider.
    origin = im.copy()
    processed_image = process(im)

    # Unique file name so concurrent requests never overwrite each other.
    name_path = f"bgremove_{uuid.uuid4()}.png"
    processed_image.save(name_path)
    return (processed_image, origin), name_path
@spaces.GPU
def process(image):
    """Attach a BiRefNet foreground mask to *image* as its alpha channel.

    The image goes through ``transform_image`` and the model. The last entry
    of the model's output is passed through a sigmoid, converted to a PIL
    image, resized back to the input's size and applied via ``putalpha``.

    Note: *image* is modified in place (``putalpha``), and the same object is
    returned.
    """
    width_height = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        # Use the last element of the model's outputs (presumably the
        # final-stage prediction; confirm against BiRefNet's docs).
        probabilities = outputs[-1].sigmoid().cpu()

    alpha = transforms.ToPILImage()(probabilities[0].squeeze())
    image.putalpha(alpha.resize(width_height))
    return image
def _png_output_path(f):
    """Return path *f* with its final file-name extension replaced by ``.png``.

    Only the extension of the last path component is replaced, so dots in
    directory names are left alone. A file without an extension simply gets
    ``.png`` appended.
    """
    import os

    return os.path.splitext(f)[0] + ".png"


def process_file(f):
    """Remove the background from the image at path *f* and save it as PNG.

    The output is written next to the input, with the extension swapped to
    ``.png``. NOTE: if *f* is already a ``.png`` the input file is
    overwritten.

    Returns:
        The path of the written PNG file.
    """
    # os.path.splitext (in the helper) replaces only the file extension.
    # The previous f.rsplit(".", 1) split on the last dot anywhere in the
    # path, e.g. "/tmp/dir.v2/photo" became "/tmp/dir.png".
    name_path = _png_output_path(f)
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Before/after comparison sliders, one per tab; both use the same label.
_SLIDER_LABEL = "Original vs AI Processed"
slider1 = ImageSlider(label=_SLIDER_LABEL, type="pil")
slider2 = ImageSlider(label=_SLIDER_LABEL, type="pil")

# Inputs: an upload widget (used by tab1) and a URL text box (used by tab2).
image_upload = gr.Image(label="Upload an image")
#image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")

# File outputs. In this file, output_file is only referenced by the
# commented-out tab3 interface.
output_file = gr.File(label="Output PNG File")
download_button = gr.File(label="Download")
download_button2 = gr.File(label="Download")
# Example inputs offered under each tab.
# NOTE: despite its name, `chameleon` holds butterfly.jpg, loaded via a
# relative path at import time.
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = (
    "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
)

# Both tabs run the same `fn`; they differ only in the input widget, the
# output components and the API endpoint name.
tab1 = gr.Interface(
    fn,
    inputs=image_upload,
    outputs=[slider1, download_button],
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn,
    inputs=url_input,
    outputs=[slider2, download_button2],
    examples=[url_example],
    api_name="text",
)
#tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
# Two-tab app. The commented line below is the earlier three-tab layout that
# also included the (disabled) "File Output" tab, tab3.
#[tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
_tabs = [tab1, tab2]
_tab_names = ["Image Input", "URL Input"]
demo = gr.TabbedInterface(_tabs, _tab_names, title="AI Background Remover Pro")
def main():
    """Launch the tabbed background-removal app.

    ``show_error=True`` makes Gradio surface exceptions raised by the handlers
    in the browser UI.
    """
    demo.launch(show_error=True)


if __name__ == "__main__":
    main()