# Hugging Face Space (status: Running)
# File size: 4,044 bytes
e546fea 67f0d9f e546fea 13a0890 9bb32c5 e546fea 958511f 392a954 958511f 7a6faa9 67f0d9f b218be6 958511f e546fea 5cb992e 6438ac6 e552388 6438ac6 958511f 3e75999 b218be6 60818d2 2d64873 67f0d9f 5cb992e 6438ac6 2d64873 7a6faa9 e546fea 3e75999 2d64873 b218be6 13a0890 6438ac6 b218be6 2d64873 e546fea 12472ea b218be6 20a2fe0 b218be6 b2b24c7 b218be6 958511f b218be6 e546fea 958511f b218be6 958511f e546fea 7455256 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import gradio as gr
from loadimg import load_img
#import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from typing import Union, Tuple
from PIL import Image
# Load the BiRefNet segmentation model once at import time (float32, no device map,
# so it stays on CPU) and switch it to evaluation mode; ``eval()`` returns the module.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "merve/BiRefNet",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=False,
    device_map=None,
).eval()
# GPU execution is disabled; to enable it, move the model with: birefnet.to("cuda")

# Standard ImageNet per-channel mean / std used to normalise the input tensor.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 1024x1024, convert to a [0, 1] float tensor,
# then normalise each channel.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
    """
    Remove the background from an image and return the original alongside the cut-out.

    This function performs background removal using a BiRefNet segmentation model. It backs
    the "Image Upload" and "URL Input" tabs. The returned pair is shown in an image slider
    for before/after comparison.

    Args:
        image: The input image, in any form ``load_img`` accepts. From the URL tab this is
            a URL/filepath string. From the upload tab it is whatever ``gr.Image`` delivers
            with its default ``type``. NOTE(review): Gradio's default ``type`` is
            ``"numpy"``, i.e. an array rather than a PIL image; confirm this for the pinned
            Gradio version.

    Returns:
        tuple: ``(origin, processed_image)``
            - origin (PIL.Image): The input converted to RGB, otherwise unchanged.
            - processed_image (PIL.Image): The RGB input with the predicted mask attached
              as its alpha (transparency) channel.

    Raises:
        gr.Error: If no image was supplied (``None``, or an empty/blank URL string).
    """
    # Show a readable message in the UI instead of an opaque loader error when the
    # user submits without uploading anything or with an empty URL box. (An
    # ``isinstance`` check is used rather than truthiness because arrays have no
    # unambiguous truth value.)
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image or an image URL.")

    im = load_img(image, output_type="pil").convert("RGB")
    # process() attaches the alpha channel to its argument in place (Image.putalpha),
    # so keep an untouched copy of the RGB input for the "before" side.
    origin = im.copy()
    processed_image = process(im)
    return (origin, processed_image)
#@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """
    Apply BiRefNet-based image segmentation to remove the background.

    The image is preprocessed with ``transform_image`` and run through ``birefnet``. The
    sigmoid of the last element of the model output is taken as a mask, resized back to
    the image's original size, and attached as the alpha (transparency) channel.

    Note: the alpha channel is applied *in place* via ``Image.putalpha``. The caller's
    image object is modified, and that same object is returned.

    Args:
        image (PIL.Image): The input RGB image.

    Returns:
        PIL.Image: The same image object, now carrying the segmentation mask as alpha.
    """
    image_size = image.size
    # Send the input to wherever the model weights live and match their dtype, so the
    # function keeps working if the model is moved to GPU (or cast to half precision)
    # instead of failing with a device/dtype mismatch.
    param = next(birefnet.parameters())
    input_images = transform_image(image).unsqueeze(0).to(device=param.device, dtype=param.dtype)
    with torch.inference_mode():
        # inference_mode already disables autograd tracking, so no .detach() is needed.
        # .float() guarantees a float32 mask for ToPILImage whatever the model dtype is.
        preds = birefnet(input_images)[-1].sigmoid().float().cpu()
    pred = preds[0].squeeze()
    # ToPILImage turns the [0, 1] float map into a single-channel 8-bit image.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    image.putalpha(mask)
    return image
def process_file(f: str) -> str:
    """
    Load an image file from disk, remove the background, and save the result as a transparent PNG.

    The output is written next to the input, with the file's extension replaced by
    ``.png``: ``photo.jpg`` becomes ``photo.png``, and a file with no extension just gets
    ``.png`` appended. An input that is already ``.png`` is overwritten.

    Args:
        f (str): Filepath of the image to process.

    Returns:
        str: Path to the saved PNG image with the background removed.

    Raises:
        gr.Error: If no file was supplied (``None`` or an empty path).
    """
    from pathlib import Path  # local import: only this function needs it

    if not f:
        raise gr.Error("Please upload an image file.")

    # Replace only the suffix of the file *name*. A plain ``f.rsplit(".", 1)`` would also
    # split on dots in parent directory names: "/tmp/run.1/image" -> "/tmp/run.png".
    name_path = str(Path(f).with_suffix(".png"))
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# --- Gradio UI ---------------------------------------------------------------
# Output components: sliders for the two interactive tabs (fn returns an
# (original, processed) pair) and a downloadable file for the PNG tab.
slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")

# Input components, one per tab.
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# Example images. The same local file backs the upload tab and the file-output tab.
EXAMPLE_IMAGE = "butterfly.jpg"
# NOTE: despite its name this holds the butterfly example; the name is kept so any
# external reference to this module attribute keeps working.
chameleon = load_img(EXAMPLE_IMAGE, output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
tab3 = gr.Interface(
    process_file, inputs=image_file_upload, outputs=output_file, examples=[EXAMPLE_IMAGE], api_name="png"
)

demo = gr.TabbedInterface(
    [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
)

if __name__ == "__main__":
    demo.launch(show_error=True)