Update u2net_utils.py
Browse files- u2net_utils.py +12 -23
u2net_utils.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
import torchvision.transforms as transforms
|
| 3 |
from PIL import Image
|
| 4 |
-
from io import BytesIO
|
| 5 |
import numpy as np
|
| 6 |
import os
|
| 7 |
import requests
|
|
@@ -46,29 +45,19 @@ def postprocess(mask, original_size):
|
|
| 46 |
mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
|
| 47 |
return mask
|
| 48 |
|
| 49 |
-
def remove_background(image
|
| 50 |
-
|
| 51 |
-
original_size = image.size
|
| 52 |
-
# Преобразуем изображение в формат, подходящий для модели
|
| 53 |
-
transform = transforms.Compose([
|
| 54 |
-
transforms.Resize((320, 320)), # только для подачи в модель
|
| 55 |
-
transforms.ToTensor(),
|
| 56 |
-
])
|
| 57 |
-
image_tensor = transform(image).unsqueeze(0)
|
| 58 |
with torch.no_grad():
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
mask = (mask - mask.min()) / (mask.max() - mask.min()) # нормализация
|
| 62 |
-
# Маску надо вернуть к оригинальному размеру
|
| 63 |
-
mask = Image.fromarray((mask * 255).astype('uint8')).resize(original_size, Image.LANCZOS)
|
| 64 |
|
| 65 |
-
# Удаляем фон, сохраняя оригинальное изображение
|
| 66 |
image = image.convert("RGBA")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
buffer.seek(0)
|
| 74 |
-
return buffer
|
|
|
|
| 1 |
import torch
|
| 2 |
import torchvision.transforms as transforms
|
| 3 |
from PIL import Image
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import os
|
| 6 |
import requests
|
|
|
|
| 45 |
mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
|
| 46 |
return mask
|
| 47 |
|
| 48 |
+
def remove_background(image):
|
| 49 |
+
input_tensor = preprocess(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
with torch.no_grad():
|
| 51 |
+
d1, *_ = model(input_tensor)
|
| 52 |
+
mask = postprocess(d1, image.size)
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
| 54 |
image = image.convert("RGBA")
|
| 55 |
+
datas = image.getdata()
|
| 56 |
+
masks = mask.getdata()
|
| 57 |
+
|
| 58 |
+
new_data = []
|
| 59 |
+
for item, m in zip(datas, masks):
|
| 60 |
+
new_data.append((item[0], item[1], item[2], m))
|
| 61 |
|
| 62 |
+
image.putdata(new_data)
|
| 63 |
+
return image
|
|
|
|
|
|