File size: 1,823 Bytes
db1e9a5 d49c6d4 db1e9a5 d49c6d4 db1e9a5 d49c6d4 db1e9a5 d49c6d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import requests
from model.u2net import U2NET
# Local cache location for the pretrained U^2-Net weights. This is a relative
# path, so it is resolved against the process's current working directory.
MODEL_DIR = "saved_models/u2net"
_WEIGHTS_FILENAME = "u2net.pth"
MODEL_PATH = os.path.join(MODEL_DIR, _WEIGHTS_FILENAME)

# Remote source of the weights; fetched by download_model() when MODEL_PATH
# does not exist yet.
MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"
def download_model():
    """Download the U^2-Net weights to MODEL_PATH unless the file already exists.

    The body is streamed into a temporary ``.part`` file next to MODEL_PATH.
    That file is moved into place only after the whole response has been
    received successfully. As a result, a failed or interrupted download never
    leaves a truncated file at MODEL_PATH. Such a file would otherwise be
    mistaken for a valid download on the next run, because only existence is
    checked.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    if os.path.exists(MODEL_PATH):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading model...")
    tmp_path = MODEL_PATH + ".part"
    try:
        # timeout bounds the connect and each read; without it a stalled
        # server would block import of this module forever.
        with requests.get(MODEL_URL, stream=True, timeout=60) as response:
            # Never persist an HTML error page as the weights file.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: MODEL_PATH is either absent or complete.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # After a successful replace the temp file is already gone; otherwise
        # discard the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded.")
# Runs at import time: ensure the weights file exists locally before the
# module-level load_model() call further down reads MODEL_PATH.
download_model()
def load_model():
    """Build a U2NET network and load the saved weights from MODEL_PATH.

    The weights are mapped onto the CPU regardless of the device they were
    saved from. The network is put into eval mode before it is returned.

    Returns:
        The U2NET instance, ready for inference.
    """
    # Presumably (in_channels, out_channels); confirm against model.u2net.
    network = U2NET(3, 1)
    # NOTE(review): torch.load unpickles the file, and this file is fetched
    # from a remote URL. On torch versions that support it, consider passing
    # weights_only=True so that only tensors and state dicts can be
    # deserialized.
    weights = torch.load(MODEL_PATH, map_location="cpu")
    network.load_state_dict(weights)
    network.eval()
    return network
# Loaded once at import time; remove_background() runs every request through
# this module-level network instance.
model = load_model()
def preprocess(image):
    """Turn a PIL image into a normalized single-image batch tensor.

    The image is converted to RGB first. ``Normalize`` below is given three
    per-channel means/stds, so grayscale ("L"), palette ("P") or RGBA inputs
    would otherwise yield a 1- or 4-channel tensor that does not match it
    and fails.

    Args:
        image: a PIL image of any mode; it is not modified.

    Returns:
        A float tensor of shape (1, 3, 320, 320).
    """
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        # The usual ImageNet per-channel mean/std values.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image.convert("RGB")).unsqueeze(0)
def postprocess(mask, original_size):
    """Convert a raw network output tensor into an 8-bit mask image.

    Args:
        mask: the network's output tensor. It is squeezed, so size-1
            batch/channel dimensions are dropped. This assumes the result is
            a 2-D map; TODO confirm against U2NET's output shape.
        original_size: ``(width, height)`` to resize the mask to, as given by
            PIL's ``Image.size``.

    Returns:
        A PIL image whose pixel values are the mask min-max rescaled to
        0-255, resized to ``original_size`` with bilinear interpolation.
    """
    values = mask.squeeze().detach().cpu().numpy()
    lo, hi = values.min(), values.max()
    if hi > lo:
        values = (values - lo) / (hi - lo)
    else:
        # A constant prediction has zero range. Min-max scaling would compute
        # 0/0 and produce NaNs, whose conversion to uint8 is not well-defined.
        # Keep the raw values instead, clipped to the [0, 1] range.
        values = np.clip(values, 0.0, 1.0)
    mask_image = Image.fromarray((values * 255).astype(np.uint8))
    return mask_image.resize(original_size, Image.BILINEAR)
def remove_background(image):
    """Return an RGBA copy of ``image`` whose alpha channel is the predicted mask.

    Args:
        image: a PIL image; it is not modified.

    Returns:
        A new RGBA PIL image that keeps the original colours. Its alpha
        channel comes from the model's mask (0 = transparent, 255 = opaque).
    """
    input_tensor = preprocess(image)
    with torch.no_grad():
        # Only the first of the network's outputs is used as the mask.
        d1, *_ = model(input_tensor)
    mask = postprocess(d1, image.size)
    # convert() returns a new image, so the caller's image is left untouched.
    result = image.convert("RGBA")
    # Replace the alpha band in one C-level call. The previous approach
    # rebuilt every pixel tuple in a Python loop (getdata/zip/putdata), which
    # is very slow on large images. The mask has the same size as the image
    # because postprocess resized it to image.size.
    result.putalpha(mask)
    return result
|