Spaces:
Runtime error
Runtime error
File size: 2,037 Bytes
5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 8029c2b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 c4d807b 5020db1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | import os
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import requests
from BiRefNet import BiRefNet
# 1. Download BiRefNet weights if not present
# NOTE(review): confirm MODEL_URL actually serves a PyTorch state_dict; a wrong
# URL now fails loudly via raise_for_status() instead of silently writing an
# error page to MODEL_PATH.
MODEL_URL = "https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/BiRefNet.pth"
MODEL_PATH = "BiRefNet.pth"


def download_weights(*, timeout=60, chunk_size=1 << 20):
    """Fetch the BiRefNet checkpoint into MODEL_PATH unless it already exists.

    The response is streamed into ``MODEL_PATH + ".part"`` and renamed into
    place only after every chunk has been written. A failed or interrupted
    download therefore never leaves a truncated checkpoint at MODEL_PATH,
    which the existence check below would otherwise accept on every later
    start-up.

    Args:
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up.
        chunk_size: Number of bytes written per chunk while streaming.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the partial or final file cannot be written.
    """
    if os.path.exists(MODEL_PATH):
        return
    print("Downloading BiRefNet weights...")
    tmp_path = MODEL_PATH + ".part"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=timeout) as response:
            # Without this check an HTTP error body (e.g. a 404 page) would be
            # saved as the checkpoint and only fail later inside torch.load.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Rename only once the file is complete, so MODEL_PATH is either
        # absent or holds the full download.
        os.replace(tmp_path, MODEL_PATH)
    except BaseException:
        # Remove the partial download (if any), then propagate the error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    print("Done downloading BiRefNet weights.")
# 2. Load BiRefNet model
def load_model():
    """Ensure the checkpoint is on disk, then build a BiRefNet from it.

    Returns:
        A ``BiRefNet`` instance with the checkpoint's state_dict loaded and
        ``eval()`` called on it, ready for inference.

    Raises:
        Whatever ``download_weights``, ``torch.load`` or ``load_state_dict``
        raise (e.g. network errors, or a key mismatch with the checkpoint).
    """
    download_weights()
    model = BiRefNet()
    # The checkpoint comes from the network. weights_only=True limits
    # unpickling to tensors and plain containers instead of executing
    # arbitrary pickle code. Requires torch >= 1.13; it is the default from
    # torch 2.6.
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Built once at import time so every request reuses the same loaded model.
bi_ref_net = load_model()
# 3. Define transforms following the upstream BiRefNet inference recipe.
# NOTE(review): values come from the ZhengPeng7/BiRefNet model card (1024x1024
# input, ImageNet mean/std normalization); confirm they match the BiRefNet
# class imported above before relying on them.
INPUT_SIZE = (1024, 1024)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),  # PIL RGB -> float tensor [C, H, W] in [0, 1]
    # The upstream recipe normalizes with ImageNet statistics; without this
    # step the network presumably sees a different input distribution than
    # it was trained on.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
def remove_bg(input_image):
    """Make the background of *input_image* transparent using BiRefNet.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        An RGBA PIL image the same size as the input, whose alpha channel is
        the model's mask resized to the original resolution. Returns ``None``
        when no image was given.
    """
    # Gradio passes None for an empty image input; return an empty output
    # instead of crashing on ``None.convert``.
    if input_image is None:
        return None

    image = input_image.convert("RGB")
    img_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = bi_ref_net(img_tensor)
    # Upstream BiRefNet returns a list of multi-scale predictions in eval mode,
    # with the final entry being the full-resolution map. A plain tensor is
    # used as-is, exactly as before.
    if isinstance(output, (list, tuple)):
        output = output[-1]
    mask = output[0, 0]  # first batch item, first channel -> [H, W]

    # NOTE(review): upstream BiRefNet emits logits and its reference code
    # applies .sigmoid(); clamping assumes this model already outputs values
    # in [0, 1] — confirm.
    mask_img = transforms.ToPILImage()(mask.cpu().clamp(0, 1))
    mask_img = mask_img.resize(image.size, Image.BILINEAR)

    # Use the mask as the alpha channel of the original-resolution image.
    result = image.convert("RGBA")
    result.putalpha(mask_img)
    return result
def _build_demo():
    """Assemble the single-function Gradio UI around ``remove_bg``.

    The input is a PIL image upload and the output is the processed PIL
    image. Title and description are the strings shown on the page.
    """
    return gr.Interface(
        fn=remove_bg,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Background Removed (PNG)"),
        title="Backdrop Studio - BiRefNet Background Removal",
        description="Upload an image to remove the background using BiRefNet AI.",
    )


# Kept as a module-level name so hosting tools that look for `demo` find it.
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()