File size: 1,823 Bytes
db1e9a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d49c6d4
 
db1e9a5
d49c6d4
 
db1e9a5
 
d49c6d4
 
 
 
 
 
db1e9a5
d49c6d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import requests
from model.u2net import U2NET

# Directory (relative to the working directory) that holds the U2NET weights.
MODEL_DIR = "saved_models/u2net"
# Full path of the weights file inside MODEL_DIR; read by load_model().
MODEL_PATH = os.path.join(MODEL_DIR, "u2net.pth")
# Remote location the weights are fetched from when MODEL_PATH does not exist.
MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"

def download_model():
    """Download the U2NET weights to MODEL_PATH unless the file already exists.

    The payload is streamed into a temporary ``.part`` file and only moved to
    MODEL_PATH after the transfer finished successfully, so an interrupted or
    failed download never leaves a truncated file that later runs would
    mistake for valid weights.

    Raises:
        requests.RequestException: on connection problems, timeouts or a
            non-2xx HTTP status.
        OSError: if the target directory or file cannot be written.
    """
    if os.path.exists(MODEL_PATH):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading model...")
    partial_path = MODEL_PATH + ".part"
    try:
        # (connect, read) timeouts: without them a stalled server blocks forever.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTTP error page as the weights.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the complete file in a single step.
        os.replace(partial_path, MODEL_PATH)
    finally:
        # After a successful replace the partial file is gone; on any failure
        # (including Ctrl-C) remove the leftover so the next run starts clean.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Model downloaded.")

# Runs at import time so the weights file is present before load_model() reads it.
download_model()

def load_model():
    """Build a U2NET(3, 1) network, load the weights stored at MODEL_PATH and
    return it switched to evaluation mode.

    The checkpoint is mapped onto the CPU regardless of the device it was
    saved from.
    """
    network = U2NET(3, 1)
    # NOTE(review): torch.load unpickles the file, and MODEL_PATH is populated
    # from a download. Consider weights_only=True if the supported torch
    # version provides it.
    state = torch.load(MODEL_PATH, map_location="cpu")
    network.load_state_dict(state)
    # Module.eval() returns the module itself.
    return network.eval()

# Module-level network instance, created once at import time and used by
# remove_background().
model = load_model()

def preprocess(image):
    """Convert a PIL image into a normalized input batch for the network.

    The image is converted to RGB, resized to 320x320, turned into a tensor
    and normalized per channel with the mean/std values below.

    Args:
        image: PIL image in any mode (RGB, RGBA, L, P, ...). It is not
            modified.

    Returns:
        Tensor of shape (1, 3, 320, 320).
    """
    # Normalize is configured with exactly three channel statistics, so
    # RGBA / grayscale / palette inputs must be brought to three channels
    # first; otherwise the transform raises on the channel mismatch.
    rgb_image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Add the leading batch dimension.
    return transform(rgb_image).unsqueeze(0)

def postprocess(mask, original_size):
    """Turn a model output tensor into an 8-bit PIL mask of the requested size.

    The values are min-max rescaled to [0, 1], scaled to 0..255 and resized
    with bilinear interpolation.

    Args:
        mask: Tensor whose singleton dimensions can be squeezed away to leave
            a 2-D map (assumed to be one output of the network — confirm
            against the caller).
        original_size: ``(width, height)`` tuple handed to ``Image.resize``.

    Returns:
        PIL image built from the uint8 array, resized to ``original_size``.
    """
    values = mask.squeeze().detach().cpu().numpy()
    low = values.min()
    span = values.max() - low
    if span > 0:
        values = (values - low) / span
    else:
        # A constant (or non-finite) map has no range to stretch; dividing by
        # zero would yield NaNs and an undefined uint8 cast, so emit a
        # well-defined all-zero mask instead.
        values = np.zeros_like(values)
    pixels = (values * 255).astype(np.uint8)
    return Image.fromarray(pixels).resize(original_size, Image.BILINEAR)

def remove_background(image):
    """Return an RGBA copy of ``image`` whose alpha channel is the predicted mask.

    The image is run through the module-level ``model``; the first of the
    model's outputs is converted to a mask at the image's size and installed
    as the alpha band. Any alpha the input already had is replaced.

    Args:
        image: PIL image. It is not modified.

    Returns:
        New PIL image in RGBA mode with the same size as ``image``.
    """
    input_tensor = preprocess(image)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        d1, *_ = model(input_tensor)
    mask = postprocess(d1, image.size)

    # convert() returns a new image, so the caller's object stays untouched.
    result = image.convert("RGBA")
    # Replace the whole alpha band in one native call instead of rebuilding
    # every pixel tuple in a Python loop.
    result.putalpha(mask)
    return result