Delete u2net_utils.py
Browse files- u2net_utils.py +0 -59
u2net_utils.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
import torch
|
| 4 |
-
import requests
|
| 5 |
-
from PIL import Image
|
| 6 |
-
import numpy as np
|
| 7 |
-
from torchvision import transforms
|
| 8 |
-
from model.u2net import U2NET
|
| 9 |
-
|
| 10 |
-
MODEL_DIR = "model"
|
| 11 |
-
MODEL_NAME = "u2net.pth"
|
| 12 |
-
MODEL_PATH = 'model/u2net.pth'
|
| 13 |
-
MODEL_URL = "https://huggingface.co/blurred-png/u2net-pth/resolve/main/u2net.pth"
|
| 14 |
-
|
| 15 |
-
def download_model():
|
| 16 |
-
if not os.path.exists(MODEL_PATH):
|
| 17 |
-
print("Downloading model from Hugging Face...")
|
| 18 |
-
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 19 |
-
r = requests.get(MODEL_URL, stream=True)
|
| 20 |
-
with open(MODEL_PATH, "wb") as f:
|
| 21 |
-
for chunk in r.iter_content(chunk_size=8192):
|
| 22 |
-
if chunk:
|
| 23 |
-
f.write(chunk)
|
| 24 |
-
print("Download complete.")
|
| 25 |
-
|
| 26 |
-
def load_model():
|
| 27 |
-
download_model()
|
| 28 |
-
net = U2NET(3, 1)
|
| 29 |
-
net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=False))
|
| 30 |
-
net.eval()
|
| 31 |
-
return net
|
| 32 |
-
|
| 33 |
-
model = load_model()
|
| 34 |
-
|
| 35 |
-
def preprocess(image):
|
| 36 |
-
transform = transforms.Compose([
|
| 37 |
-
transforms.Resize((320, 320)),
|
| 38 |
-
transforms.ToTensor(),
|
| 39 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 40 |
-
std=[0.229, 0.224, 0.225])
|
| 41 |
-
])
|
| 42 |
-
return transform(image).unsqueeze(0)
|
| 43 |
-
|
| 44 |
-
def postprocess(mask, original_size):
|
| 45 |
-
mask = mask.squeeze().cpu().data.numpy()
|
| 46 |
-
mask = (mask - mask.min()) / (mask.max() - mask.min())
|
| 47 |
-
mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
|
| 48 |
-
return mask
|
| 49 |
-
|
| 50 |
-
def remove_background(image):
|
| 51 |
-
original_size = image.size
|
| 52 |
-
tensor = preprocess(image)
|
| 53 |
-
with torch.no_grad():
|
| 54 |
-
d = model(tensor)[0][:, 0, :, :]
|
| 55 |
-
mask = postprocess(d, original_size)
|
| 56 |
-
image = image.convert("RGBA")
|
| 57 |
-
r, g, b, _ = image.split()
|
| 58 |
-
merged = Image.merge("RGBA", (r, g, b, mask))
|
| 59 |
-
return merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|