File size: 2,767 Bytes
e546fea
 
 
 
 
 
 
1c37c95
e546fea
958511f
e546fea
958511f
 
 
 
b218be6
958511f
 
 
 
 
 
 
e546fea
 
958511f
 
3e75999
b218be6
9352739
1c37c95
835324b
6f29338
2d64873
 
 
 
958511f
e546fea
 
 
 
 
3e75999
 
2d64873
b218be6
2d64873
b218be6
2d64873
 
 
 
 
e546fea
4e70dde
 
b218be6
5f967b0
b218be6
 
c6d7e41
4e70dde
 
7b293a2
b218be6
b2b24c7
b218be6
958511f
5f967b0
4e70dde
c6d7e41
e546fea
958511f
c6d7e41
2dd436d
c6d7e41
958511f
e546fea
 
b218be6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import uuid

# "high" trades a little float32 matmul precision for speed (PyTorch may use
# TF32 / reduced-precision internals where the hardware supports it).
torch.set_float32_matmul_precision("high")

# Segmentation model used by process(). trust_remote_code=True lets
# transformers execute the model code published in the Hub repository.
_BIREFNET_REPO = "ZhengPeng7/BiRefNet"
birefnet = AutoModelForImageSegmentation.from_pretrained(_BIREFNET_REPO, trust_remote_code=True)
# Loaded once at import time and kept on the GPU for every request.
birefnet.to("cuda")

# Per-channel RGB mean / std used to normalise the model input
# (the standard ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor at the fixed 1024x1024 model resolution.
# The aspect ratio is not preserved; process() resizes the mask back afterwards.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

def fn(image):
    """Remove the background from an uploaded image or an image URL.

    Args:
        image: Value supplied by the Gradio input component, passed to
            ``load_img``. For the upload tab this is whatever ``gr.Image``
            produces (its default type, presumably a numpy array); for the
            URL tab it is the textbox string.

    Returns:
        ``((processed_image, origin), name_path)``. The inner pair feeds the
        ImageSlider: the RGBA cut-out first, then an untouched RGB copy of
        the input. ``name_path`` is the PNG file the cut-out was saved to.

    Raises:
        gr.Error: if no image was uploaded or the URL box is blank.
    """
    # Pasted URLs often carry stray whitespace / newlines.
    if isinstance(image, str):
        image = image.strip()
    # Without this guard an empty submission fails deep inside load_img with
    # an unhelpful message.
    if image is None or (isinstance(image, str) and not image):
        raise gr.Error("Please provide an image or an image URL.")

    im = load_img(image, output_type="pil").convert("RGB")
    # process() calls putalpha on its argument (mutating it), so keep a
    # pristine copy for the "original" side of the slider.
    origin = im.copy()
    processed_image = process(im)

    # uuid keeps concurrent requests from clobbering each other's output.
    name_path = f"bgremove_{uuid.uuid4()}.png"
    processed_image.save(name_path)
    return (processed_image, origin), name_path

@spaces.GPU
def process(image):
    """Predict a foreground mask with BiRefNet and attach it as alpha.

    The given PIL image is modified in place via ``putalpha`` and the same
    object is returned.
    """
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only: no autograd graph. The last element of the model
    # output is taken as the final prediction map (assumption based on the
    # [-1] indexing; confirm against BiRefNet's output format).
    with torch.no_grad():
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # First (only) batch item; squeeze drops singleton dims so ToPILImage
    # gets a 2-D map (assumes a single-channel prediction).
    single_map = probabilities[0].squeeze()
    mask_at_model_res = transforms.ToPILImage()(single_map)
    image.putalpha(mask_at_model_res.resize(original_size))
    return image

def process_file(f):
    """Remove the background from the image file at ``f`` and save a PNG.

    The output is written next to the input with its extension replaced by
    ``.png``. If the input already has a ``.png`` extension (any case), the
    output is named ``<stem>_nobg.png`` instead so the source file is never
    overwritten.

    Args:
        f: Path to the input image (str or path-like).

    Returns:
        The output PNG path as a string.
    """
    from pathlib import Path

    src = Path(f)
    # with_suffix only touches the final path component, unlike
    # str.rsplit(".") which could cut inside a dotted directory name.
    if src.suffix.lower() == ".png":
        out = src.with_name(f"{src.stem}_nobg.png")
    else:
        out = src.with_suffix(".png")

    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(out)
    return str(out)

# Before/after sliders; each tab gets its own instance.
_SLIDER_LABEL = "Original vs AI Processed"
slider1 = ImageSlider(label=_SLIDER_LABEL, type="pil")
slider2 = ImageSlider(label=_SLIDER_LABEL, type="pil")

# Inputs: direct upload for the first tab, pasted URL for the second.
image_upload = gr.Image(label="Upload an image")
url_input = gr.Textbox(label="Paste an image URL")

# File outputs. output_file is not wired to any visible tab at the moment.
output_file = gr.File(label="Output PNG File")
download_button = gr.File(label="Download")
download_button2 = gr.File(label="Download")


# Examples: a local image for the upload tab, a remote URL for the URL tab.
# (The variable is named "chameleon" but loads butterfly.jpg.)
chameleon = load_img("butterfly.jpg", output_type="pil")
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Both tabs share fn(); only the input widget and output components differ.
tab1 = gr.Interface(
    fn,
    inputs=image_upload,
    outputs=[slider1, download_button],
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn,
    inputs=url_input,
    outputs=[slider2, download_button2],
    examples=[url_example],
    api_name="text",
)
# A third tab wrapping process_file (file in -> PNG file out) is currently disabled.

# Top-level app: one tab per input method.
demo = gr.TabbedInterface(
    [tab1, tab2],
    tab_names=["Image Input", "URL Input"],
    title="AI Background Remover Pro",
)

if __name__ == "__main__":
    # show_error=True surfaces Python exceptions in the browser UI.
    demo.launch(show_error=True)