Update u2net_utils.py
Browse files- u2net_utils.py +19 -12
u2net_utils.py
CHANGED
|
@@ -45,19 +45,26 @@ def postprocess(mask, original_size):
|
|
| 45 |
mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
|
| 46 |
return mask
|
| 47 |
|
| 48 |
-
def remove_background(image):
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
with torch.no_grad():
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
| 54 |
image = image.convert("RGBA")
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
new_data = []
|
| 59 |
-
for item, m in zip(datas, masks):
|
| 60 |
-
new_data.append((item[0], item[1], item[2], m))
|
| 61 |
|
| 62 |
-
|
| 63 |
-
return image
|
|
|
|
| 45 |
mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
|
| 46 |
return mask
|
| 47 |
|
| 48 |
+
def remove_background(image: Image.Image) -> Image.Image:
|
| 49 |
+
# Сохраняем оригинальный размер
|
| 50 |
+
original_size = image.size
|
| 51 |
+
# Преобразуем изображение в формат, подходящий для модели
|
| 52 |
+
transform = transforms.Compose([
|
| 53 |
+
transforms.Resize((320, 320)), # только для подачи в модель
|
| 54 |
+
transforms.ToTensor(),
|
| 55 |
+
])
|
| 56 |
+
image_tensor = transform(image).unsqueeze(0)
|
| 57 |
with torch.no_grad():
|
| 58 |
+
output = model(image_tensor)[0][0] # предположим, что это U2NET
|
| 59 |
+
mask = output.squeeze().cpu().numpy()
|
| 60 |
+
mask = (mask - mask.min()) / (mask.max() - mask.min()) # нормализация
|
| 61 |
+
# Маску надо вернуть к оригинальному размеру
|
| 62 |
+
mask = Image.fromarray((mask * 255).astype('uint8')).resize(original_size, Image.LANCZOS)
|
| 63 |
|
| 64 |
+
# Удаляем фон, сохраняя оригинальное изображение
|
| 65 |
image = image.convert("RGBA")
|
| 66 |
+
mask = mask.convert("L")
|
| 67 |
+
new_image = Image.new("RGBA", original_size)
|
| 68 |
+
new_image.paste(image, (0, 0), mask=mask)
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
return new_image
|
|
|