# removebg / u2net_utils.py
# Last change by bluspater: "Update u2net_utils.py" (commit d49c6d4, verified)
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import requests
from model.u2net import U2NET
MODEL_DIR = "saved_models/u2net"
MODEL_PATH = os.path.join(MODEL_DIR, "u2net.pth")
MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"


def download_model(timeout=60):
    """Download the U^2-Net weights to MODEL_PATH unless the file already exists.

    The payload is streamed into a temporary ``MODEL_PATH + ".part"`` file and
    only moved onto MODEL_PATH after the whole body was written. An interrupted
    or failed transfer therefore never leaves a truncated MODEL_PATH behind.
    Such a file would otherwise pass the existence check on every later run and
    then fail inside ``torch.load``.

    Args:
        timeout: seconds passed to ``requests.get`` as its connect/read timeout,
            so a stalled server cannot hang the import forever.

    Raises:
        requests.HTTPError: if the server answers with an error status (the
            error page is not saved as the model file).
        requests.RequestException: on connection failures or timeouts.
    """
    if os.path.exists(MODEL_PATH):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading model...")
    tmp_path = MODEL_PATH + ".part"
    try:
        # Using the response as a context manager releases the connection
        # even if writing the file fails midway.
        with requests.get(MODEL_URL, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # On success the temp file was already renamed away; on failure drop
        # the partial download so the next run retries from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded.")


# Fetch the weights at import time so load_model() can read MODEL_PATH.
download_model()
def load_model():
    """Create a U2NET(3, 1) network with the weights stored at MODEL_PATH.

    Weights are mapped onto the CPU, and the network is switched to
    evaluation mode before being returned.
    """
    network = U2NET(3, 1)
    # NOTE(review): torch.load unpickles the file, and MODEL_PATH is fetched
    # from a remote URL. Consider weights_only=True once the minimum supported
    # torch version (>= 1.13) allows it.
    weights = torch.load(MODEL_PATH, map_location="cpu")
    network.load_state_dict(weights)
    return network.eval()


# Module-level singleton used by remove_background().
model = load_model()
def preprocess(image):
    """Convert a PIL image into the normalized batch tensor fed to the model.

    The image is forced to 3-channel RGB first. The Normalize step below has
    exactly three mean/std entries, so RGBA, grayscale ("L") or palette ("P")
    inputs would otherwise fail with a channel-count mismatch. Note that RGBA
    alpha is simply dropped, not composited.

    Args:
        image: a PIL image of any mode.

    Returns:
        A float tensor of shape (1, 3, 320, 320).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension expected by the network.
    return transform(image).unsqueeze(0)
def postprocess(mask, original_size):
    """Turn a raw model output map into an 8-bit PIL mask at the original size.

    The map is min-max rescaled to 0..255, converted to a mode "L" image and
    resized bilinearly.

    Args:
        mask: tensor output of the model. It is squeezed to 2-D here, so it is
            assumed to contain a single spatial map (e.g. (1, 1, H, W)) — TODO
            confirm against U2NET.
        original_size: (width, height), as returned by PIL ``Image.size``.

    Returns:
        A mode "L" PIL image of size ``original_size``.
    """
    mask = mask.detach().squeeze().cpu().numpy()
    lo, hi = mask.min(), mask.max()
    span = hi - lo
    if span > 0:
        mask = (mask - lo) / span
    else:
        # A constant map would divide by zero and produce NaNs, whose uint8
        # cast is undefined. It carries no relative information, so map it to
        # all zeros explicitly.
        mask = np.zeros_like(mask)
    mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
    return mask
def remove_background(image):
    """Return an RGBA copy of *image* whose alpha channel is the predicted mask.

    Args:
        image: a PIL image. It is not modified; ``convert`` always returns a
            new image.

    Returns:
        A new RGBA PIL image of the same size. The RGB values come from
        ``image.convert("RGBA")`` and alpha is the 0..255 mask from
        postprocess().
    """
    input_tensor = preprocess(image)
    with torch.no_grad():
        # Only the first output map is used; any remaining outputs are ignored.
        d1, *_ = model(input_tensor)
    mask = postprocess(d1, image.size)
    result = image.convert("RGBA")
    # putalpha replaces the alpha band in one C-level call. It gives the same
    # (r, g, b, m) pixels as rebuilding every tuple in a Python loop, without
    # the O(pixels) interpreter overhead and the large temporary list.
    result.putalpha(mask)
    return result