# BirefNet_rembg / app.py
# Author: um41r — last commit "Update app.py" (634c276, verified)
import torch
import gradio as gr
import numpy as np
from PIL import Image
from safetensors.torch import load_file
from birefnet import BiRefNet
from BiRefNet_config import BiRefNetConfig
# All inference in this app runs on the CPU (the model and every input tensor
# are moved to this device).
device = "cpu"
# ======================
# LOAD MODEL
# ======================
config = BiRefNetConfig()
model = BiRefNet(config)
state_dict = load_file("model.safetensors")
# strict=False tolerates key mismatches between the checkpoint and the
# architecture, but ignoring them silently can leave layers at their random
# initialisation and produce garbage masks with no error. Keep the lenient
# load, but surface any mismatch in the logs.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
    print(
        f"⚠️ {len(load_report.missing_keys)} model parameters not found in "
        f"checkpoint (first few: {load_report.missing_keys[:5]})"
    )
if load_report.unexpected_keys:
    print(
        f"⚠️ {len(load_report.unexpected_keys)} checkpoint entries not used by "
        f"the model (first few: {load_report.unexpected_keys[:5]})"
    )
model.to(device)
model.eval()  # switch from training to evaluation mode before serving requests
print("✅ BiRefNet Lite loaded")
# ======================
# UTILS
# ======================
# ImageNet channel statistics (RGB order). The upstream BiRefNet inference
# example normalises inputs with exactly these values; feeding raw [0, 1]
# pixels shifts the input distribution away from what the model expects.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(img, size=(1024, 1024)):
    """Turn a PIL image into a normalised, batched model input tensor.

    Args:
        img: PIL image of any mode; it is converted to RGB first.
        size: ``(width, height)`` the image is resized to before inference.
            Defaults to 1024x1024, the resolution this app has always used.

    Returns:
        float32 tensor of shape ``(1, 3, height, width)``. Pixels are scaled to
        [0, 1] and then normalised per channel with ImageNet mean/std.
    """
    rgb = img.convert("RGB").resize(size)
    arr = np.array(rgb).astype(np.float32) / 255.0  # (H, W, 3), values in [0, 1]
    arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD  # broadcasts over the channel axis
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW, the layout torch conv layers expect
    return torch.from_numpy(arr).unsqueeze(0)  # add the batch dimension
@torch.no_grad()
def infer(image):
    """Cut the background out of *image* with the loaded BiRefNet model.

    Gradient tracking is disabled for the whole call (decorator above).

    Args:
        image: PIL image to segment.

    Returns:
        An RGBA copy of *image* whose alpha channel is the predicted foreground
        probability mapped to 0-255 and resized back to the input's size.
    """
    batch = preprocess(image).to(device)
    # NOTE(review): the upstream BiRefNet examples take the last output ([-1]);
    # confirm that [0] selects the final prediction for this model.
    logits = model(batch)[0]
    probs = torch.sigmoid(logits)
    alpha_values = (probs.squeeze().cpu().numpy() * 255).astype(np.uint8)
    alpha = Image.fromarray(alpha_values).resize(image.size)

    cutout = image.convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
# ======================
# API + UI FUNCTION
# ======================
def remove_bg(image):
    """Gradio UI / API handler: remove the background from an uploaded image.

    Args:
        image: numpy array supplied by ``gr.Image(type="numpy")``, or ``None``
            when the request arrives without an image.

    Returns:
        RGBA PIL image whose alpha channel is the predicted foreground mask.

    Raises:
        gr.Error: if no image was provided, so the caller sees a readable
            message instead of an internal traceback.
    """
    # No API key check for testing.
    if image is None:
        raise gr.Error("Please upload an image.")
    return infer(Image.fromarray(image))
# ======================
# GRADIO INTERFACE
# ======================
# Input arrives as a numpy array; output is returned as a PIL image (RGBA cutout).
input_image = gr.Image(type="numpy", label="Image")
output_image = gr.Image(type="pil")

demo = gr.Interface(
    fn=remove_bg,
    inputs=input_image,
    outputs=output_image,
    title="BiRefNet Lite – Background Remover",
    api_name="remove-bg",  # ✅ exposes API endpoint
)

# Server-side rendering is explicitly turned off for this app.
demo.launch(ssr_mode=False)