bluspater commited on
Commit
d49c6d4
·
verified ·
1 Parent(s): e3eab30

Update u2net_utils.py

Browse files
Files changed (1) hide show
  1. u2net_utils.py +12 -23
u2net_utils.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
- from io import BytesIO
5
  import numpy as np
6
  import os
7
  import requests
@@ -46,29 +45,19 @@ def postprocess(mask, original_size):
46
  mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
47
  return mask
48
 
49
- def remove_background(image: Image.Image) -> Image.Image:
50
- # Сохраняем оригинальный размер
51
- original_size = image.size
52
- # Преобразуем изображение в формат, подходящий для модели
53
- transform = transforms.Compose([
54
- transforms.Resize((320, 320)), # только для подачи в модель
55
- transforms.ToTensor(),
56
- ])
57
- image_tensor = transform(image).unsqueeze(0)
58
  with torch.no_grad():
59
- output = model(image_tensor)[0][0] # предположим, что это U2NET
60
- mask = output.squeeze().cpu().numpy()
61
- mask = (mask - mask.min()) / (mask.max() - mask.min()) # нормализация
62
- # Маску надо вернуть к оригинальному размеру
63
- mask = Image.fromarray((mask * 255).astype('uint8')).resize(original_size, Image.LANCZOS)
64
 
65
- # Удаляем фон, сохраняя оригинальное изображение
66
  image = image.convert("RGBA")
67
- mask = mask.convert("L")
68
- new_image = Image.new("RGBA", original_size)
69
- new_image.paste(image, (0, 0), mask=mask)
 
 
 
70
 
71
- buffer = BytesIO()
72
- new_image.save(buffer, format="PNG", optimize=False, compress_level=1)
73
- buffer.seek(0)
74
- return buffer
 
1
  import torch
2
  import torchvision.transforms as transforms
3
  from PIL import Image
 
4
  import numpy as np
5
  import os
6
  import requests
 
45
  mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
46
  return mask
47
 
48
+ def remove_background(image):
49
+ input_tensor = preprocess(image)
 
 
 
 
 
 
 
50
  with torch.no_grad():
51
+ d1, *_ = model(input_tensor)
52
+ mask = postprocess(d1, image.size)
 
 
 
53
 
 
54
  image = image.convert("RGBA")
55
+ datas = image.getdata()
56
+ masks = mask.getdata()
57
+
58
+ new_data = []
59
+ for item, m in zip(datas, masks):
60
+ new_data.append((item[0], item[1], item[2], m))
61
 
62
+ image.putdata(new_data)
63
+ return image