bluspater commited on
Commit
08108ec
·
verified ·
1 Parent(s): 64ed02f

Update u2net_utils.py

Browse files
Files changed (1) hide show
  1. u2net_utils.py +19 -12
u2net_utils.py CHANGED
@@ -45,19 +45,26 @@ def postprocess(mask, original_size):
45
  mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
46
  return mask
47
 
48
- def remove_background(image):
49
- input_tensor = preprocess(image)
 
 
 
 
 
 
 
50
  with torch.no_grad():
51
- d1, *_ = model(input_tensor)
52
- mask = postprocess(d1, image.size)
 
 
 
53
 
 
54
  image = image.convert("RGBA")
55
- datas = image.getdata()
56
- masks = mask.getdata()
57
-
58
- new_data = []
59
- for item, m in zip(datas, masks):
60
- new_data.append((item[0], item[1], item[2], m))
61
 
62
- image.putdata(new_data)
63
- return image
 
45
  mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
46
  return mask
47
 
48
+ def remove_background(image: Image.Image) -> Image.Image:
49
+ # Сохраняем оригинальный размер
50
+ original_size = image.size
51
+ # Преобразуем изображение в формат, подходящий для модели
52
+ transform = transforms.Compose([
53
+ transforms.Resize((320, 320)), # только для подачи в модель
54
+ transforms.ToTensor(),
55
+ ])
56
+ image_tensor = transform(image).unsqueeze(0)
57
  with torch.no_grad():
58
+ output = model(image_tensor)[0][0] # предположим, что это U2NET
59
+ mask = output.squeeze().cpu().numpy()
60
+ mask = (mask - mask.min()) / (mask.max() - mask.min()) # нормализация
61
+ # Маску надо вернуть к оригинальному размеру
62
+ mask = Image.fromarray((mask * 255).astype('uint8')).resize(original_size, Image.LANCZOS)
63
 
64
+ # Удаляем фон, сохраняя оригинальное изображение
65
  image = image.convert("RGBA")
66
+ mask = mask.convert("L")
67
+ new_image = Image.new("RGBA", original_size)
68
+ new_image.paste(image, (0, 0), mask=mask)
 
 
 
69
 
70
+ return new_image