File size: 2,241 Bytes
e546fea
 
 
 
 
 
 
2b4b81f
e546fea
958511f
e546fea
958511f
 
 
3618356
958511f
 
 
 
 
 
 
e546fea
 
958511f
 
3e75999
2d64873
2b4b81f
2d64873
 
 
f0a6ca3
e546fea
 
 
 
 
3e75999
2b4b81f
 
3e75999
2b4b81f
 
 
2d64873
 
1cbf784
2d64873
 
 
8be6065
 
2d64873
e546fea
2b4b81f
c368dca
e546fea
2d64873
e546fea
1cbf784
e546fea
20a2fe0
b2b24c7
958511f
521737f
958511f
2d64873
958511f
e546fea
958511f
1cbf784
958511f
e546fea
 
2373e76
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image

# Select the "high" float32 matmul precision setting (trades some float32
# matmul precision for speed where the backend supports it; see torch docs).
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model from the Hugging Face Hub. trust_remote_code is
# required because the checkpoint ships its own model code. Runs on the CPU.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cpu")

# Model preprocessing: resize to a fixed square input, convert PIL -> float
# tensor, then normalize per channel with the standard ImageNet mean/std.
_MODEL_INPUT_SIZE = (1024, 1024)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

def fn(image):
    """Load *image* and return it with its background replaced by white.

    Args:
        image: Any input accepted by ``loadimg.load_img`` (presumably a path,
            URL or array — confirm against the loadimg documentation).

    Returns:
        The RGBA PIL image produced by ``process``.
    """
    # The model pipeline expects 3-channel input, so normalize the mode first.
    pil_image = load_img(image, output_type="pil").convert("RGB")
    return process(pil_image)

def process(image):
    """Segment the foreground of *image* with BiRefNet and composite it onto white.

    Args:
        image: A PIL image. It is converted to RGB for the model and to an
            RGBA copy for compositing, so the caller's image is not modified.

    Returns:
        A new RGBA image with the same size as *image*: its pixels blended by
        the predicted foreground mask over an opaque white background.
    """
    image_size = image.size
    # Convert to RGB so the 3-channel Normalize step accepts any PIL mode.
    input_images = transform_image(image.convert("RGB")).unsqueeze(0).to("cpu")

    # Prediction: the last entry of the model output (presumably the final-stage
    # mask logits) is squashed to [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # Assumes the first prediction squeezes to a single (H, W) map — TODO confirm.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # putalpha mutates in place, so apply it to an RGBA copy rather than to the
    # caller's image.
    foreground = image.convert("RGBA")
    foreground.putalpha(mask)

    white_background = Image.new("RGBA", image_size, (255, 255, 255, 255))
    return Image.alpha_composite(white_background, foreground)
  
def process_file(f):
    """Remove the background of the image at path *f* and save it as a JPEG.

    The output is written next to the input, with the file's extension
    replaced by ``.jpeg`` (or appended if the file name has none).

    Args:
        f: Path (``str``) of the input image, e.g. the file path Gradio passes
            from the ``type="filepath"`` image input.

    Returns:
        The output file path as a ``str``.
    """
    from pathlib import Path

    # with_suffix only touches the final path component; splitting the whole
    # string on its last "." would cut inside a dotted directory name when the
    # file itself has no extension.
    name_path = str(Path(f).with_suffix(".jpeg"))
    # NOTE(review): an input that already ends in ".jpeg" is overwritten in place.
    im = load_img(f, output_type="pil").convert("RGB")
    composited = process(im)
    # JPEG cannot store alpha; the composite is on an opaque white background,
    # so converting to RGB loses nothing visible.
    composited.convert("RGB").save(name_path)
    return name_path

# --- Gradio components --------------------------------------------------------
# NOTE(review): only image2 and png_file are wired into the interface below;
# slider1, slider2, image and text look unused in the visible code.
slider1 = gr.Image()
slider2 = ImageSlider(type="pil", label="birefnet")
image = gr.Image(label="Upload an image")
image2 = gr.Image(type="filepath", label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
# Despite the variable name, this component serves the JPEG written by process_file.
png_file = gr.File(label="output jpeg file")

# Example image read at import time; not referenced elsewhere in the visible code.
chameleon = load_img("butterfly.jpg", output_type="pil")

url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Single tab: image file path in -> JPEG file out, exposed on the API as "png".
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["butterfly.jpg"],
    api_name="png",
)

demo = gr.TabbedInterface(
    interface_list=[tab3],
    tab_names=["jpeg"],
    title="Na Na",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run directly.
    demo.launch(share=True)