import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from .Utils import radialspace_1D, radialspace_2D, cv2_layer
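# Color adjustment node: brightness, contrast, saturation, hue, gamma and sharpness
# via torchvision, plus per-channel red/green/blue multipliers. Alpha is preserved
# for RGBA inputs.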
class ImageEffectsAdjustment:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"brightness": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"contrast": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"saturation": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"hue": ("FLOAT", {
"default": 0.5,
"max": 1.0,
"step": 0.01
}),
"gamma": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"sharpness": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"red": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"green": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
"blue": ("FLOAT", {
"default": 1.0,
"step": 0.01
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects"
def node(self, images, brightness, contrast, saturation, hue, gamma, sharpness, red, green, blue):
# noinspection PyUnboundLocalVariable
def apply(img):
            rgba = False

            # split off the alpha channel so the torchvision adjustments operate on RGB only
            if img.shape[2] == 4:
                a = img[:, :, 3].unsqueeze(2)
                img = img[:, :, 0:3]
                rgba = True
img = img.permute(2, 0, 1)
if brightness != 1.0:
img = F.adjust_brightness(img, brightness)
if contrast != 1.0:
img = F.adjust_contrast(img, contrast)
if saturation != 1.0:
img = F.adjust_saturation(img, saturation)
if hue != 0.5:
img = F.adjust_hue(img, hue - 0.5)
if gamma != 1.0:
img = F.adjust_gamma(img, gamma)
if sharpness != 1.0:
img = F.adjust_sharpness(img, sharpness)
img = img.permute(1, 2, 0)
r, g, b = torch.chunk(img, 3, dim=2)
r = torch.clamp(r * red, 0, 1)
g = torch.clamp(g * green, 0, 1)
b = torch.clamp(b * blue, 0, 1)
if rgba:
return torch.cat([r, g, b, a], dim=2)
else:
return torch.cat([r, g, b], dim=2)
return (torch.stack([
apply(images[i]) for i in range(len(images))
]),)
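# Grayscale node: converts each image to grayscale by averaging its channels
# and replicating the result across R, G and B.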
class ImageEffectsGrayscale:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects"
def node(self, images):
def apply(image):
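            # simple per-pixel channel average (not a luminance-weighted conversion)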
tensor = image.clone().detach()
grayscale_tensor = torch.mean(tensor, dim=2, keepdim=True)
return torch.cat([grayscale_tensor] * 3, dim=2)
return (torch.stack([
apply(images[i]) for i in range(len(images))
]),)
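# Negative node: inverts the RGB channels (1.0 - value), leaving alpha untouched.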
class ImageEffectsNegative:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects"
def node(self, images):
tensor = images.clone().detach()
tensor[:, :, :, 0:3] = 1.0 - tensor[:, :, :, 0:3]
return (tensor,)
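# Sepia node: applies the classic sepia color matrix to the RGB channels.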
class ImageEffectsSepia:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects"
def node(self, images):
tensor = images.clone().detach()
sepia_mask = torch.tensor([[0.393, 0.349, 0.272],
[0.769, 0.686, 0.534],
[0.189, 0.168, 0.131]])
        tensor[:, :, :, 0:3] = torch.stack([
            torch.matmul(tensor[i, :, :, 0:3], sepia_mask) for i in range(len(tensor))
        ])
        # the sepia weights sum to more than 1.0 per channel, so clamp back to [0, 1]
        tensor[:, :, :, 0:3] = torch.clamp(tensor[:, :, :, 0:3], 0.0, 1.0)
return (tensor,)
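# Zoom-burst lens effect: repeatedly remaps the image toward/away from a chosen
# center point and blends the samples, optionally with rotation, to simulate
# zooming the lens during a long exposure.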
class ImageEffectsLensZoomBurst:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"scale": ("FLOAT", {
"default": 1.5,
"min": 1.0,
"step": 0.01
}),
"samples": ("INT", {
"default": 100,
"min": 1,
}),
"position_x": ("FLOAT", {
"default": 0.5,
"max": 1.0,
"step": 0.01
}),
"position_y": ("FLOAT", {
"default": 0.5,
"max": 1.0,
"step": 0.01
}),
"rotation": ("FLOAT", {
"default": 0.0,
"min": 0.0,
"max": 360.0,
}),
"method": (["circle", "point"],),
"stabilization": (["true", "false"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects/lens"
# noinspection PyUnresolvedReferences
def zoom_burst(
self,
image,
scale,
samples,
position,
rotation,
method,
stabilization,
):
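        # Build per-pixel remap tables that sample toward ("up") and away from ("down")
        # the zoom center, then blend `samples` progressively warped copies to create
        # the zoom streaks. With stabilization both directions are averaged each step
        # so the subject stays roughly in place; otherwise only the zoom-in map is used.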
if scale < 1.0:
raise ValueError("Parameter scale can't be smaller then initial image size.")
h, w = image.shape[:2]
x = np.arange(w)
y = np.arange(h)
xx, yy = np.meshgrid(x, y)
cx = int(w * position[0])
cy = int(h * position[1])
if method == "circle":
d = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
max_d = np.sqrt((w / 2) ** 2 + (h / 2) ** 2)
map_x_up = (xx - d * (scale - 1.0) / max_d * (xx - cx) / samples).astype(np.float32)
map_y_up = (yy - d * (scale - 1.0) / max_d * (yy - cy) / samples).astype(np.float32)
map_x_down = (xx + d * (scale - 1.0) / max_d * (xx - cx) / samples).astype(np.float32)
map_y_down = (yy + d * (scale - 1.0) / max_d * (yy - cy) / samples).astype(np.float32)
elif method == "point":
map_x_up = (xx - (xx - cx) * (scale - 1.0) / samples).astype(np.float32)
map_y_up = (yy - (yy - cy) * (scale - 1.0) / samples).astype(np.float32)
map_x_down = (xx + (xx - cx) * (scale - 1.0) / samples).astype(np.float32)
map_y_down = (yy + (yy - cy) * (scale - 1.0) / samples).astype(np.float32)
else:
raise ValueError("Unsupported method.")
if rotation > 0.0:
angle_step = rotation / samples
rm_up = cv2.getRotationMatrix2D((cx, cy), angle_step, 1)
rm_down = cv2.getRotationMatrix2D((cx, cy), -angle_step, 1)
else:
vibration_angle = 1.0
vibration_step = vibration_angle / samples
rm_up_even = cv2.getRotationMatrix2D((cx, cy), vibration_step, 1)
rm_down_even = cv2.getRotationMatrix2D((cx, cy), -vibration_step, 1)
rm_up_odd = cv2.getRotationMatrix2D((cx, cy), -vibration_step, 1)
rm_down_odd = cv2.getRotationMatrix2D((cx, cy), vibration_step, 1)
for i in range(samples):
if stabilization:
tmp_up = cv2.remap(image, map_x_up, map_y_up, cv2.INTER_LINEAR)
tmp_down = cv2.remap(image, map_x_down, map_y_down, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
if rotation > 0.0:
tmp_up = cv2.warpAffine(tmp_up, rm_up, (w, h), borderMode=cv2.BORDER_REFLECT)
tmp_down = cv2.warpAffine(tmp_down, rm_down, (w, h), borderMode=cv2.BORDER_REFLECT)
else:
if i % 2 == 0:
tmp_up = cv2.warpAffine(tmp_up, rm_up_even, (w, h), borderMode=cv2.BORDER_REFLECT)
tmp_down = cv2.warpAffine(tmp_down, rm_down_even, (w, h), borderMode=cv2.BORDER_REFLECT)
else:
tmp_up = cv2.warpAffine(tmp_up, rm_up_odd, (w, h), borderMode=cv2.BORDER_REFLECT)
tmp_down = cv2.warpAffine(tmp_down, rm_down_odd, (w, h), borderMode=cv2.BORDER_REFLECT)
image = cv2.addWeighted(tmp_up, 0.5, tmp_down, 0.5, 0)
else:
tmp = cv2.remap(image, map_x_up, map_y_up, cv2.INTER_LINEAR)
if rotation > 0.0:
rm = cv2.getRotationMatrix2D((cx, cy), angle_step, 1)
tmp = cv2.warpAffine(tmp, rm, (w, h), borderMode=cv2.BORDER_REFLECT)
else:
if i % 2 == 0:
tmp = cv2.warpAffine(tmp, rm_up_even, (w, h), borderMode=cv2.BORDER_REFLECT)
else:
tmp = cv2.warpAffine(tmp, rm_up_odd, (w, h), borderMode=cv2.BORDER_REFLECT)
image = cv2.addWeighted(tmp, 0.5, image, 0.5, 0)
return image
def node(self, images, scale, samples, position_x, position_y, rotation, method, stabilization):
tensor = images.clone().detach()
return (cv2_layer(tensor, lambda x: self.zoom_burst(
            x, scale, samples, (position_x, position_y), rotation, method, stabilization == "true"
)),)
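# Chromatic aberration lens effect: shifts two of the three color channels row by row
# (and column by column in "reflect" mode) using shift profiles from radialspace_1D,
# producing color fringing toward the image edges.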
class ImageEffectsLensChromaticAberration:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"shift": ("INT", {
"default": 10,
"step": 1,
}),
"method": (["reflect", "edge", "constant"],),
"shift_type": ("INT", {
"default": 1,
"min": 1,
"max": 4,
"step": 1,
}),
"mixing_type": ("INT", {
"default": 1,
"min": 1,
"max": 4,
"step": 1,
}),
"transpose": (["none", "rotate", "reflect"],),
"colors": (["rb", "rg", "gb"],),
"lens_curvy": ("FLOAT", {
"default": 1.0,
"max": 15.0,
"step": 0.1,
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects/lens"
def node(self, images, shift, method, shift_type, mixing_type, transpose, colors, lens_curvy):
# noinspection PyUnboundLocalVariable
def apply(image):
img = image.clone().detach()
if transpose == "rotate":
img = img.permute(1, 0, 2)
elif transpose == "none" or transpose == "reflect":
pass
else:
raise ValueError("Not existing reverse.")
r, g, b = img[:, :, 0:3].split(1, 2)
height, width = img[:, :, 0].shape
def get_space(size, min_val, max_val):
size += shift * 2
if min_val == max_val:
return torch.linspace(min_val, max_val, size)
else:
return radialspace_1D(size, lens_curvy, 1.0, min_val, max_val)
if shift_type == 1:
f_shifts = get_space(height, -shift, shift)
if transpose == "reflect":
t_shifts = get_space(width, -shift, shift)
elif shift_type == 2:
f_shifts = get_space(height, 0, shift)
f_shifts = torch.flip(f_shifts, dims=(0,))
if transpose == "reflect":
t_shifts = get_space(width, 0, shift)
t_shifts = torch.flip(t_shifts, dims=(0,))
elif shift_type == 3:
f_shifts = get_space(height, 0, shift)
if transpose == "reflect":
t_shifts = get_space(width, 0, shift)
elif shift_type == 4:
f_shifts = get_space(height, shift, shift)
if transpose == "reflect":
t_shifts = get_space(width, shift, shift)
else:
raise ValueError("Not existing shift_type.")
if mixing_type == 1:
f_shifts = f_shifts
s_shifts = -f_shifts
if transpose == "reflect":
t_shifts = t_shifts
c_shifts = -t_shifts
elif mixing_type == 2:
f_shifts = -f_shifts
s_shifts = f_shifts
if transpose == "reflect":
t_shifts = -t_shifts
c_shifts = t_shifts
elif mixing_type == 3:
f_shifts = f_shifts
s_shifts = f_shifts
if transpose == "reflect":
t_shifts = t_shifts
c_shifts = t_shifts
elif mixing_type == 4:
f_shifts = -f_shifts
s_shifts = -f_shifts
if transpose == "reflect":
t_shifts = -t_shifts
c_shifts = -t_shifts
else:
raise ValueError("Not existing mixing_type.")
if colors == "rb":
def cat(f_value, s_value):
return torch.cat([f_value, g, s_value], 2)
f = r
s = b
elif colors == "rg":
def cat(f_value, s_value):
return torch.cat([f_value, s_value, b], 2)
f = r
s = g
elif colors == "gb":
def cat(f_value, s_value):
return torch.cat([r, f_value, s_value], 2)
f = g
s = b
else:
raise ValueError("Not existing colors.")
f_pad = F.pad(f.squeeze(), [shift, shift], padding_mode=method).unsqueeze(2)
s_pad = F.pad(s.squeeze(), [shift, shift], padding_mode=method).unsqueeze(2)
f_shifted = torch.zeros_like(f_pad)
s_shifted = torch.zeros_like(s_pad)
for i in range(height + (shift * 2)):
f_shifted[i] = torch.roll(f_pad[i], shifts=int(f_shifts[i]), dims=0)
s_shifted[i] = torch.roll(s_pad[i], shifts=int(s_shifts[i]), dims=0)
if transpose == "reflect":
for i in range(width + (shift * 2)):
f_shifted[:, i] = torch.roll(f_shifted[:, i], shifts=int(t_shifts[i]), dims=0)
s_shifted[:, i] = torch.roll(s_shifted[:, i], shifts=int(c_shifts[i]), dims=0)
f_result = f_shifted[shift:-shift, shift:-shift, :]
s_result = s_shifted[shift:-shift, shift:-shift, :]
img[:, :, 0:3] = cat(f_result, s_result)
if transpose == "rotate":
img = img.permute(1, 0, 2)
elif transpose == "none" or transpose == "reflect":
pass
else:
raise ValueError("Not existing reverse.")
return img
return (torch.stack([
apply(images[i]) for i in range(len(images))
]),)
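# Bokeh lens effect: builds a polygonal aperture ("blades") mask and uses it either
# as a dilation kernel or as a normalized convolution filter, optionally after a
# bilateral or stack pre-blur.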
class ImageEffectsLensBokeh:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"blades_shape": ("INT", {
"default": 5,
"min": 3,
}),
"blades_radius": ("INT", {
"default": 10,
"min": 1,
}),
"blades_rotation": ("FLOAT", {
"default": 0.0,
"min": 0.0,
"max": 360.0,
}),
"blur_size": ("INT", {
"default": 10,
"min": 1,
"step": 2
}),
"blur_type": (["bilateral", "stack", "none"],),
"method": (["dilate", "filter"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "node"
CATEGORY = "image/effects/lens"
# noinspection PyUnresolvedReferences
def lens_blur(self, image, blades_shape, blades_radius, blades_rotation, method):
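        # Rasterize the aperture polygon (blades_shape sides, rotated by blades_rotation)
        # into a binary mask, then apply it either as a dilation structuring element or
        # as a normalized convolution kernel depending on `method`.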
angles = np.linspace(0, 2 * np.pi, blades_shape + 1)[:-1] + blades_rotation * np.pi / 180
x = blades_radius * np.cos(angles) + blades_radius
y = blades_radius * np.sin(angles) + blades_radius
pts = np.stack([x, y], axis=1).astype(np.int32)
mask = np.zeros((blades_radius * 2 + 1, blades_radius * 2 + 1), np.uint8)
cv2.fillPoly(mask, [pts], 255)
        # note: despite the name, this 3x3 kernel is a sharpening kernel, not a Gaussian
        gaussian_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
if method == "dilate":
kernel = cv2.filter2D(mask, -1, gaussian_kernel)
result = cv2.dilate(image, kernel)
elif method == "filter":
height, width = image.shape[:2]
dilate_size = min(height, width) // 512
if dilate_size > 0:
image = cv2.dilate(image, np.ones((dilate_size, dilate_size), np.uint8))
kernel = mask.astype(np.float32) / np.sum(mask)
kernel = cv2.filter2D(kernel, -1, gaussian_kernel)
result = cv2.filter2D(image, -1, kernel)
else:
raise ValueError("Unsupported method.")
return result
def node(self, images, blades_shape, blades_radius, blades_rotation, blur_size, blur_type, method):
tensor = images.clone().detach()
blur_size -= 1
if blur_type == "bilateral":
tensor = cv2_layer(tensor, lambda x: cv2.bilateralFilter(x, blur_size, -100, 100))
elif blur_type == "stack":
tensor = cv2_layer(tensor, lambda x: cv2.stackBlur(x, (blur_size, blur_size)))
elif blur_type == "none":
pass
else:
raise ValueError("Unsupported blur type.")
return (cv2_layer(tensor, lambda x: self.lens_blur(
x, blades_shape, blades_radius, blades_rotation, method
)),)
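# Optic-axis lens effect: displaces pixels radially around the image center using a
# radialspace_2D mask, blurs the displaced region and blends it back in; also returns
# the shifted mask.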
class ImageEffectsLensOpticAxis:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"lens_shape": (["circle", "square", "rectangle", "corners"],),
"lens_edge": (["around", "symmetric"],),
"lens_curvy": ("FLOAT", {
"default": 4.0,
"max": 15.0,
"step": 0.1,
}),
"lens_zoom": ("FLOAT", {
"default": 2.0,
"step": 0.1,
}),
"lens_aperture": ("FLOAT", {
"default": 0.5,
"max": 10.0,
"step": 0.1,
}),
"blur_intensity": ("INT", {
"default": 30,
"min": 2,
"step": 2
}),
},
}
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "node"
CATEGORY = "image/effects/lens"
def node(self, images, lens_shape, lens_edge, lens_curvy, lens_zoom, lens_aperture, blur_intensity):
blur_intensity -= 1
lens_zoom += 1
height, width = images[0, :, :, 0].shape
if lens_edge == "around":
mask = radialspace_2D((height, width), lens_curvy, lens_zoom, lens_shape, 0.0, 1.0 + lens_curvy).unsqueeze(0).unsqueeze(3)
elif lens_edge == "symmetric":
if height != width:
new_height = new_width = max(height, width)
crop_top_bottom = (new_height - height) // 2
crop_left_right = (new_width - width) // 2
mask = radialspace_2D((new_height, new_width), lens_curvy, lens_zoom, lens_shape, 0.0, 1.0 + lens_curvy)[
crop_top_bottom:-crop_top_bottom or None,
crop_left_right:-crop_left_right or None
].unsqueeze(0).unsqueeze(3)
else:
mask = radialspace_2D((height, width), lens_curvy, lens_zoom, lens_shape, 0.0, 1.0 + lens_curvy).unsqueeze(0).unsqueeze(3)
else:
raise ValueError("Not existing lens_edge parameter.")
center_x = width // 2
center_y = height // 2
y_v, x_v = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
dx = x_v - center_x
dy = y_v - center_y
        # clamp to avoid a divide-by-zero (NaN) at the exact center pixel
        distance = torch.sqrt(dx ** 2 + dy ** 2).clamp(min=1e-6)
map_x = x_v + mask[0, :, :, 0] * dx / distance * (-lens_aperture * 100)
map_y = y_v + mask[0, :, :, 0] * dy / distance * (-lens_aperture * 100)
map_x = map_x.to(torch.float32).numpy()
map_y = map_y.to(torch.float32).numpy()
shifted_images = cv2_layer(images, lambda x: cv2.remap(x, map_x, map_y, cv2.INTER_LINEAR))
shifted_mask = cv2_layer(mask, lambda x: cv2.remap(x, map_x, map_y, cv2.INTER_LINEAR))
edited_images = cv2_layer(shifted_images, lambda x: cv2.stackBlur(x, (blur_intensity, blur_intensity)))
mask = torch.clamp(mask, 0.0, 1.0)
shifted_mask = torch.clamp(shifted_mask, 0.0, 1.0)
result = shifted_images * (1 - mask) + edited_images * mask
return result, shifted_mask
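# Vignette lens effect: darkens and desaturates the image toward the edges using a
# radialspace_2D mask and returns both the blended image and the mask.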
class ImageEffectsLensVignette:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"lens_shape": (["circle", "rectangle"],),
"lens_edge": (["around", "symmetric"],),
"lens_curvy": ("FLOAT", {
"default": 3.0,
"max": 15.0,
"step": 0.1,
}),
"lens_zoom": ("FLOAT", {
"default": 0.0,
"step": 0.1,
}),
"brightness": ("FLOAT", {
"default": 0.25,
"max": 1.0,
"step": 0.01
}),
"saturation": ("FLOAT", {
"default": 0.5,
"max": 1.0,
"step": 0.01
}),
},
}
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "node"
CATEGORY = "image/effects/lens"
def node(self, images, lens_shape, lens_edge, lens_curvy, lens_zoom, brightness, saturation):
tensor = images.clone().detach()
lens_zoom += 1
height, width = tensor[0, :, :, 0].shape
if lens_edge == "around":
mask = radialspace_2D((height, width), lens_curvy, lens_zoom, lens_shape).unsqueeze(0).unsqueeze(3)
elif lens_edge == "symmetric":
if height != width:
new_height = new_width = max(height, width)
crop_top_bottom = (new_height - height) // 2
crop_left_right = (new_width - width) // 2
mask = radialspace_2D((new_height, new_width), lens_curvy, lens_zoom, lens_shape)[
crop_top_bottom:-crop_top_bottom or None,
crop_left_right:-crop_left_right or None
].unsqueeze(0).unsqueeze(3)
else:
mask = radialspace_2D((height, width), lens_curvy, lens_zoom, lens_shape).unsqueeze(0).unsqueeze(3)
else:
raise ValueError("Not existing lens_edge parameter.")
tensor = tensor.permute(0, 3, 1, 2)
tensor[:, 0:3, :, :] = F.adjust_brightness(tensor[:, 0:3, :, :], brightness)
tensor[:, 0:3, :, :] = F.adjust_saturation(tensor[:, 0:3, :, :], saturation)
tensor = tensor.permute(0, 2, 3, 1)
result = images * (1 - mask) + tensor * mask
mask = mask.squeeze()
return result, mask
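# Mapping of node names to classes, used by ComfyUI to register these nodes.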
NODE_CLASS_MAPPINGS = {
"ImageEffectsAdjustment": ImageEffectsAdjustment,
"ImageEffectsGrayscale": ImageEffectsGrayscale,
"ImageEffectsNegative": ImageEffectsNegative,
"ImageEffectsSepia": ImageEffectsSepia,
"ImageEffectsLensZoomBurst": ImageEffectsLensZoomBurst,
"ImageEffectsLensChromaticAberration": ImageEffectsLensChromaticAberration,
"ImageEffectsLensBokeh": ImageEffectsLensBokeh,
"ImageEffectsLensOpticAxis": ImageEffectsLensOpticAxis,
"ImageEffectsLensVignette": ImageEffectsLensVignette
}