File size: 2,037 Bytes
5020db1
c4d807b
 
 
 
 
5020db1
c4d807b
5020db1
 
 
c4d807b
 
5020db1
 
 
 
 
 
c4d807b
5020db1
 
8029c2b
5020db1
 
 
 
 
c4d807b
 
 
5020db1
 
 
 
 
 
c4d807b
5020db1
 
 
 
 
 
 
c4d807b
5020db1
 
 
c4d807b
5020db1
 
c4d807b
 
 
 
 
 
5020db1
c4d807b
5020db1
c4d807b
 
 
5020db1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import requests
from BiRefNet import BiRefNet

# 1. Download BiRefNet weights if not present
MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
MODEL_PATH = "BiRefNet.pth"


def download_weights(timeout=60, chunk_size=1 << 20):
    """Download the BiRefNet checkpoint to MODEL_PATH unless it already exists.

    The response is streamed to a temporary sibling file. That file is renamed
    into place only after the full body has been written, so a failed or
    interrupted download never leaves a truncated MODEL_PATH behind. Without
    this, the existence check would silently accept a truncated file on
    every later run.

    Args:
        timeout: Seconds to wait for the server (connect / read) before giving up.
        chunk_size: Number of bytes per streamed chunk written to disk.

    Raises:
        requests.HTTPError: If the server answers with an error status (e.g. 404).
        requests.RequestException: On network errors or timeouts.
        OSError: If the file cannot be written or renamed.
    """
    if os.path.exists(MODEL_PATH):
        return

    print("Downloading BiRefNet weights...")
    tmp_path = MODEL_PATH + ".part"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=timeout) as response:
            # Without this check, an HTTP error page would be saved as the weights.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream to disk instead of buffering the whole checkpoint in memory.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(tmp_path, MODEL_PATH)
    except BaseException:
        # Drop the partial file so the next run retries cleanly, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Done downloading BiRefNet weights.")

# 2. Load BiRefNet model
def load_model():
    """Return a BiRefNet network with pretrained weights, in eval mode.

    The checkpoint is fetched first if it is missing locally. Its tensors are
    mapped onto the CPU and loaded into a freshly constructed BiRefNet.
    """
    download_weights()
    net = BiRefNet()
    # SECURITY NOTE(review): torch.load unpickles MODEL_PATH, which is a file
    # downloaded from the network. Consider torch.load(..., weights_only=True)
    # (available in newer torch releases) or a safetensors checkpoint once the
    # installed torch version is confirmed.
    weights = torch.load(MODEL_PATH, map_location="cpu")
    net.load_state_dict(weights)
    # Inference-only usage: disable training-mode layer behaviour.
    net.eval()
    return net

# Module-level model instance, built once at import time and used by every
# remove_bg call. On the first run this also downloads the weights via
# load_model -> download_weights.
bi_ref_net = load_model()

# 3. Define the input transform applied by remove_bg before inference.
# BiRefNet's published inference recipe resizes to 1024x1024 and applies
# ImageNet mean/std normalization. The previous 224x224, un-normalized input
# does not match that recipe.
# NOTE(review): assumes the local `BiRefNet` module is the reference
# architecture and does not normalize internally; confirm.
BIREFNET_INPUT_SIZE = (1024, 1024)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(BIREFNET_INPUT_SIZE),
    transforms.ToTensor(),  # PIL RGB -> float tensor [C, H, W] in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

def remove_bg(input_image):
    """Remove the background of an image using the module-level BiRefNet model.

    Args:
        input_image: PIL image supplied by the Gradio input component, or None
            when the user submits without uploading anything.

    Returns:
        An RGBA PIL image at the input's original size, whose alpha channel is
        the predicted foreground mask. Returns None if no image was given.
    """
    if input_image is None:
        # Gradio passes None for an empty image input; avoid an AttributeError.
        return None

    image = input_image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = bi_ref_net(img_tensor)

    if isinstance(output, (list, tuple)):
        # The reference BiRefNet returns a list of logit maps with the final
        # prediction last; sigmoid maps logits to [0, 1].
        # NOTE(review): confirm the local BiRefNet wrapper follows this convention.
        mask = output[-1].sigmoid()[0, 0]
    else:
        # Single-tensor output: first batch item, first channel -> [H, W].
        # Values are assumed to already be probabilities (clamped below).
        mask = output[0, 0]

    # Clamp to [0, 1], convert to a grayscale PIL image, and scale it back to
    # the original resolution so it can serve as the alpha channel.
    mask_img = transforms.ToPILImage()(mask.cpu().clamp(0, 1))
    mask_img = mask_img.resize(image.size, Image.BILINEAR)

    # Create RGBA output with the mask as alpha.
    result = image.convert("RGBA")
    result.putalpha(mask_img)
    return result

# Gradio UI: a single PIL image in, the background-removed RGBA PIL image out.
_input_component = gr.Image(type="pil", label="Input Image")
_output_component = gr.Image(type="pil", label="Background Removed (PNG)")

demo = gr.Interface(
    fn=remove_bg,
    inputs=_input_component,
    outputs=_output_component,
    title="Backdrop Studio - BiRefNet Background Removal",
    description="Upload an image to remove the background using BiRefNet AI.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()