File size: 4,370 Bytes
1e3b872 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
class AlphaChanelAdd:
    """Node that appends a fully opaque alpha channel to an image batch.

    Images that already have 4 channels are passed through unchanged.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Append an all-ones alpha plane as the last channel.

        Args:
            images: 4-D tensor unpacked as (batch, height, width, channels).

        Returns:
            A 1-tuple holding the image tensor. It has an extra last-axis
            channel of ones, unless the input already had 4 channels, in
            which case the input is returned as-is.
        """
        batch, height, width, channels = images.shape
        if channels == 4:
            # Already RGBA. Still wrap in a tuple, like every other return path.
            return (images,)
        # Match the input's dtype/device so torch.cat does not fail on
        # GPU tensors or silently promote dtypes.
        alpha = torch.ones(
            (batch, height, width, 1), dtype=images.dtype, device=images.device
        )
        return (torch.cat((images, alpha), dim=-1),)
class AlphaChanelAddByMask:
    """Node that builds an RGBA image from RGB channels plus a mask.

    With method "default" the alpha is ``1 - mask``; any other method
    (the UI offers "invert") uses the mask directly as alpha.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "mask": ("MASK",),
                "method": (["default", "invert"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images, mask, method):
        """Replace/append the alpha channel of ``images`` using ``mask``.

        Args:
            images: 4-D tensor unpacked as (count, height, width, channels);
                at least 3 channels are required. Only the first 3 are kept.
            mask: (count, height, width) or (height, width) tensor. A mask of
                count 1 is broadcast over the image batch. A 64x64 mask whose
                size differs from the images is treated as "no mask" (all
                zeros). Presumably this is the host's placeholder empty
                mask; confirm.
            method: "default" -> alpha = 1 - mask; anything else -> alpha = mask.

        Returns:
            A 1-tuple with a (count, height, width, 4) tensor.

        Raises:
            ValueError: if images have fewer than 3 channels, or if the mask
                size differs from the image size (and is not the 64x64
                placeholder).
        """
        img_count, img_height, img_width, img_channels = images.shape
        if img_channels < 3:
            raise ValueError(
                "[AlphaChanelByMask]: Images must have at least 3 channels."
            )
        if mask.dim() == 2:
            # Single (height, width) mask: give it a batch axis.
            mask = mask.unsqueeze(0)
        _, mask_height, mask_width = mask.shape
        same_size = (mask_height, mask_width) == (img_height, img_width)
        if (mask_height, mask_width) == (64, 64) and not same_size:
            # Placeholder mask: only substitute when it cannot be a real mask
            # for these images, so genuine 64x64 masks are honoured.
            mask = torch.zeros(
                (img_count, img_height, img_width),
                dtype=images.dtype,
                device=images.device,
            )
        elif not same_size:
            raise ValueError(
                "[AlphaChanelByMask]: Size of images not equals size of mask. "
                f"Images: [{img_width}, {img_height}] - "
                f"Mask: [{mask_width}, {mask_height}]."
            )
        mask = mask.to(device=images.device, dtype=images.dtype)
        if mask.shape[0] != img_count:
            # Broadcasting only succeeds for a single mask; other mismatches
            # raise from expand(), as before.
            mask = mask.expand((img_count, -1, -1))
        alpha = 1.0 - mask if method == "default" else mask
        return (torch.cat((images[..., :3], alpha.unsqueeze(-1)), dim=-1),)
class AlphaChanelAsMask:
    """Node that extracts the alpha channel of an RGBA batch as a mask."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "method": (["default", "invert"],),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images, method):
        """Return the alpha channel of every image in the batch as a mask.

        Args:
            images: tensor whose last axis is channels; it must have exactly
                4 channels. Assumed to be (batch, height, width, 4).
            method: "default" -> mask = 1 - alpha; "invert" -> mask = alpha.

        Returns:
            A 1-tuple with a (batch, height, width) mask. The whole batch is
            kept, not just the first image, so the mask can be fed straight
            back into a node that unpacks a 3-D mask shape.

        Raises:
            ValueError: if there is no 4th (alpha) channel, or the method is
                unknown.
        """
        if images.shape[-1] != 4:
            raise ValueError("Alpha chanel not exist.")
        alpha = images[..., 3]
        if method == "default":
            return (1.0 - alpha,)
        if method == "invert":
            return (alpha,)
        raise ValueError("Unexpected method.")
class AlphaChanelRestore:
    """Node that resets an existing alpha channel to fully opaque (1.0)."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Return a copy of ``images`` whose 4th channel is all ones.

        Args:
            images: 4-D tensor unpacked as (batch, height, width, channels).

        Returns:
            A 1-tuple with the image tensor. Inputs that do not have exactly
            4 channels are returned unchanged. The input tensor itself is
            never modified.
        """
        _, _, _, channels = images.shape
        if channels != 4:
            # Wrap in a tuple, like every other return path.
            return (images,)
        restored = images.clone().detach()
        # Scalar fill: no temporary tensor and no device/dtype mismatch.
        restored[..., 3] = 1.0
        return (restored,)
class AlphaChanelRemove:
    """Node that strips an image batch down to its first three channels."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {"images": ("IMAGE",)}
        return {"required": required_inputs}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Keep channels 0-2 of a 4-D (batch, height, width, channels) tensor.

        Returns:
            A 1-tuple holding the sliced view of ``images``.
        """
        rgb_only = images[:, :, :, :3]
        return (rgb_only,)
# Public node-name -> implementing-class registry for this module
# (presumably discovered by the host node framework on import; confirm).
NODE_CLASS_MAPPINGS = dict(
    AlphaChanelAdd=AlphaChanelAdd,
    AlphaChanelAddByMask=AlphaChanelAddByMask,
    AlphaChanelAsMask=AlphaChanelAsMask,
    AlphaChanelRestore=AlphaChanelRestore,
    AlphaChanelRemove=AlphaChanelRemove,
)
|