Spaces:
Sleeping
Sleeping
File size: 2,241 Bytes
e546fea 2b4b81f e546fea 958511f e546fea 958511f 3618356 958511f e546fea 958511f 3e75999 2d64873 2b4b81f 2d64873 f0a6ca3 e546fea 3e75999 2b4b81f 3e75999 2b4b81f 2d64873 1cbf784 2d64873 8be6065 2d64873 e546fea 2b4b81f c368dca e546fea 2d64873 e546fea 1cbf784 e546fea 20a2fe0 b2b24c7 958511f 521737f 958511f 2d64873 958511f e546fea 958511f 1cbf784 958511f e546fea 2373e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
# Select the "high" float32 matmul precision setting (as opposed to "highest").
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation checkpoint. trust_remote_code=True lets
# transformers run the modelling code shipped with the checkpoint repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Preprocessing applied to every input image before it reaches the model:
# resize to 1024x1024, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def fn(image):
    """Load an image, force it to RGB, and return the processed result.

    Args:
        image: Any input that ``load_img`` accepts.

    Returns:
        Whatever ``process`` returns for the RGB version of the input.
    """
    # The previous version also made an unused full copy of the image here;
    # it was never read, so it has been dropped.
    rgb = load_img(image, output_type="pil").convert("RGB")
    return process(rgb)
def process(image):
    """Cut out the predicted foreground of ``image`` and place it on white.

    The image is run through ``transform_image`` and ``birefnet``; the last
    model output is squashed with a sigmoid, turned into a PIL mask, resized
    back to the input size, and used as the alpha channel when compositing
    the picture over an opaque white canvas.

    Args:
        image: A PIL image. Callers in this file always pass RGB images.
            The argument is left unmodified.

    Returns:
        An RGBA PIL image of the same size as ``image``.
    """
    image_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cpu")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        preds = birefnet(batch)[-1].sigmoid().cpu()

    # First item of the batch -> PIL mask, scaled back to the input size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)

    # putalpha() edits the image in place, so apply it to a copy; otherwise
    # the caller's image would silently gain an alpha channel.
    cutout = image.copy()
    cutout.putalpha(mask)

    white_background = Image.new("RGBA", image_size, (255, 255, 255, 255))
    return Image.alpha_composite(white_background, cutout)
def process_file(f):
    """Process the image file at ``f`` and save the result as a JPEG.

    The output is written next to the input, using the input path with its
    extension replaced by ``.jpeg``.

    Args:
        f: Path of the input image file, as a string.

    Returns:
        The path of the JPEG file that was written.
    """
    import os  # local import so this block does not depend on the file header

    # splitext() only strips an extension from the final path component, so a
    # dot in a directory name (e.g. "/tmp/run.1/upload") cannot truncate the
    # output path the way splitting on the last "." in the whole string did.
    name_path = os.path.splitext(f)[0] + ".jpeg"

    rgb = load_img(f, output_type="pil").convert("RGB")
    composited = process(rgb)

    # JPEG cannot store an alpha channel, so convert back to RGB before saving.
    composited.convert("RGB").save(name_path)
    return name_path
# --- Gradio components -------------------------------------------------------
# Only image2 and png_file are wired into the interface built below; the other
# components are created but not referenced again in the code visible here.
slider1 = gr.Image()
slider2 = ImageSlider(
    label="birefnet",
    type="pil",
)
image = gr.Image(label="Upload an image")
image2 = gr.Image(
    label="Upload an image",
    type="filepath",
)
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output jpeg file")

# Sample inputs: a local example image loaded eagerly, and a remote image URL.
chameleon = load_img("butterfly.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# --- App layout --------------------------------------------------------------
# A single tab: upload an image file path, get back the processed JPEG file.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["butterfly.jpg"],
    api_name="png",
)
demo = gr.TabbedInterface(
    [tab3],
    ["jpeg"],
    title="Na Na",
)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch(share=True)
|