File size: 1,140 Bytes
efc1c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from PIL import Image
import numpy as np
import torch

class SwarmRemBg:
    """Node that strips the background from every image of a batch via ``rembg``.

    The class-level attributes (``INPUT_TYPES``, ``CATEGORY``, ``RETURN_TYPES``,
    ``FUNCTION``) look like the ComfyUI custom-node declaration convention;
    the host that consumes them is not visible in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration: one required ``IMAGE`` input named ``images``."""
        return {
            "required": {
                "images": ("IMAGE",),
            }
        }

    CATEGORY = "SwarmUI/images"
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "rem"

    def rem(self, images):
        """Remove the background of each image and return the cutouts and masks.

        Args:
            images: Iterable of per-image tensors. Assumed to be a
                ``(batch, height, width, channels)`` float tensor with values
                in ``[0, 1]`` -- TODO confirm against the caller.

        Returns:
            A ``(images, masks)`` tuple of float32 tensors. ``images`` holds
            the processed frames scaled back to ``[0, 1]``; ``masks`` holds
            ``1 - alpha`` per frame, or all zeros when a frame has no alpha
            band.
        """
        # Handle an empty batch up front: no need to import rembg, and the
        # result keeps a sensible (0, H, W, ...) float32 shape instead of the
        # shape-(0,) float64 tensor that np.array([]) would produce.
        if len(images) == 0:
            shape = tuple(getattr(images, "shape", ()))
            height, width = shape[1:3] if len(shape) >= 3 else (0, 0)
            # 4 channels mirrors the RGBA frames the loop below is expected to
            # produce (rembg's output mode is assumed, not visible here).
            return (
                torch.zeros((0, height, width, 4), dtype=torch.float32),
                torch.zeros((0, height, width), dtype=torch.float32),
            )

        # Imported lazily so the module can be loaded without rembg installed.
        from rembg import remove

        output = []
        masks = []
        for image in images:
            # Scale to 0..255 and convert to an 8-bit PIL image for rembg.
            pixels = 255.0 * image.cpu().numpy()
            img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
            img = img.convert("RGBA")
            img = remove(img, post_process_mask=True)
            output.append(np.array(img).astype(np.float32) / 255.0)
            if 'A' in img.getbands():
                alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0
                # The mask is the inverse of opacity: 1.0 where alpha is 0.
                masks.append(1. - alpha)
            else:
                # Size the fallback mask to the image itself (previously a
                # fixed 64x64) so every mask in the batch stacks cleanly.
                masks.append(np.zeros((img.height, img.width), dtype=np.float32))
        return (torch.from_numpy(np.stack(output)), torch.from_numpy(np.stack(masks)))

# Maps the node's registered name to its implementing class; presumably read
# by the host's custom-node loader (consumer not visible here -- confirm).
NODE_CLASS_MAPPINGS = {"SwarmRemBg": SwarmRemBg}