bluspater commited on
Commit
202082b
·
verified ·
1 Parent(s): 120846b

Delete u2net_utils.py

Browse files
Files changed (1) hide show
  1. u2net_utils.py +0 -59
u2net_utils.py DELETED
@@ -1,59 +0,0 @@
1
-
2
- import os
3
- import torch
4
- import requests
5
- from PIL import Image
6
- import numpy as np
7
- from torchvision import transforms
8
- from model.u2net import U2NET
9
-
10
- MODEL_DIR = "model"
11
- MODEL_NAME = "u2net.pth"
12
- MODEL_PATH = 'model/u2net.pth'
13
- MODEL_URL = "https://huggingface.co/blurred-png/u2net-pth/resolve/main/u2net.pth"
14
-
15
- def download_model():
16
- if not os.path.exists(MODEL_PATH):
17
- print("Downloading model from Hugging Face...")
18
- os.makedirs(MODEL_DIR, exist_ok=True)
19
- r = requests.get(MODEL_URL, stream=True)
20
- with open(MODEL_PATH, "wb") as f:
21
- for chunk in r.iter_content(chunk_size=8192):
22
- if chunk:
23
- f.write(chunk)
24
- print("Download complete.")
25
-
26
- def load_model():
27
- download_model()
28
- net = U2NET(3, 1)
29
- net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=False))
30
- net.eval()
31
- return net
32
-
33
- model = load_model()
34
-
35
- def preprocess(image):
36
- transform = transforms.Compose([
37
- transforms.Resize((320, 320)),
38
- transforms.ToTensor(),
39
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
40
- std=[0.229, 0.224, 0.225])
41
- ])
42
- return transform(image).unsqueeze(0)
43
-
44
- def postprocess(mask, original_size):
45
- mask = mask.squeeze().cpu().data.numpy()
46
- mask = (mask - mask.min()) / (mask.max() - mask.min())
47
- mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
48
- return mask
49
-
50
- def remove_background(image):
51
- original_size = image.size
52
- tensor = preprocess(image)
53
- with torch.no_grad():
54
- d = model(tensor)[0][:, 0, :, :]
55
- mask = postprocess(d, original_size)
56
- image = image.convert("RGBA")
57
- r, g, b, _ = image.split()
58
- merged = Image.merge("RGBA", (r, g, b, mask))
59
- return merged