3v324v23's picture
lfs
1e3b872
from typing import Dict, Any, List, Optional, Tuple
from PIL import ImageFont, Image, ImageDraw
from torch import Tensor
import matplotlib.font_manager as fm
from custom_nodes.Comfy_KepListStuff.utils import tensor2pil, pil2tensor
# HACK: a str subclass whose "!=" comparison never reports a difference.
class AnyType(str):
    """A ``str`` whose ``!=`` operator always returns ``False``.

    Only inequality is overridden; ``==`` and hashing remain ordinary ``str``
    behaviour. Presumably this lets the wildcard pass a type-compatibility
    check that is written with ``!=`` (NOTE(review): confirm against the
    ComfyUI version in use).
    """

    def __ne__(self, __value: object) -> bool:
        # Report "not different" no matter what the other operand is.
        return False


# Wildcard type name: the string "*" that never compares unequal to anything.
ANY: AnyType = AnyType("*")
class StackImages:
    """ComfyUI node that tiles a list of image batches into a single grid image.

    Every entry of ``images`` is one batch. The images of a batch are placed
    side by side (``batch_stack_mode == "horizontal"``) or on top of each other
    (otherwise) to form one grid cell. Cells are consumed in order, in runs
    whose lengths come from ``splits``:

    - With ``stack_mode == "horizontal"``, each run fills one column and the
      runs go left to right.
    - Otherwise, each run fills one row and the runs go top to bottom.

    Optional label lists add a white band with red text above the grid
    (``horizontal_labels``, one per column) and/or to its left
    (``vertical_labels``, one per row).
    """

    # Thickness in pixels of each label band; also used as the label font size.
    LABEL_SIZE = 60

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
        """Declare the node's inputs for ComfyUI."""
        return {
            "required": {
                "images": ("IMAGE",),
                "splits": ("INT", {"forceInput": True, "min": 1}),
                "stack_mode": (["horizontal", "vertical"], {"default": "horizontal"}),
                "batch_stack_mode": (["horizontal", "vertical"], {"default": "horizontal"}),
            },
            "optional": {
                "horizontal_labels": (ANY, {}),
                "vertical_labels": (ANY, {}),
            },
        }

    RELOAD_INST = True
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Image",)
    # Inputs arrive as lists (see the List[...] parameters of stack_images);
    # a single grid image is produced.
    INPUT_IS_LIST = (True,)
    OUTPUT_IS_LIST = (False,)
    OUTPUT_NODE = True
    FUNCTION = "stack_images"
    CATEGORY = "List Stuff"

    def stack_images(
        self,
        images: List[Tensor],
        splits: List[int],
        stack_mode: List[str],
        batch_stack_mode: List[str],
        horizontal_labels: Optional[List[str]] = None,
        vertical_labels: Optional[List[str]] = None,
    ) -> Tuple[Tensor]:
        """Build the grid image.

        Args:
            images: List of image batches. Each batch is iterated as individual
                images whose ``.size()`` unpacks to (H, W, C). All images are
                assumed to share the first image's size (TODO confirm callers
                guarantee this).
            splits: Either a single run length ``n`` or explicit run lengths.
                With a single ``n``, batches are grouped in runs of ``n`` and any
                remainder forms a final shorter run. Explicit run lengths must
                be non-negative and sum to ``len(images)``.
            stack_mode: One-element list. ``"horizontal"`` makes each run a
                column; anything else makes it a row.
            batch_stack_mode: One-element list. ``"horizontal"`` lays out a
                batch's images left to right; anything else lays them out top
                to bottom.
            horizontal_labels: Optional labels, one per grid column.
            vertical_labels: Optional labels, one per grid row.

        Returns:
            A one-element tuple holding the grid converted with ``pil2tensor``.

        Raises:
            Exception: on malformed modes, splits or label counts, or when
                ``images`` is empty.
        """
        if len(stack_mode) != 1:
            raise Exception("Only single stack mode supported.")
        if len(batch_stack_mode) != 1:
            raise Exception("Only single batch stack mode supported.")
        stack_direction = stack_mode[0]
        batch_stack_direction = batch_stack_mode[0]

        if not images:
            # Without this check, images[0] fails below with an opaque IndexError.
            raise Exception("At least one image batch is required.")
        splits = self._resolve_splits(splits, len(images))

        # Size cells for the longest batch so extra frames in later batches are
        # not silently cropped by paste(). This matches first-batch sizing
        # whenever all batches have the same length.
        batch_size = max(len(batch) for batch in images)
        image_h, image_w, _ = images[0][0].size()
        if batch_stack_direction == "horizontal":
            batch_w, batch_h = image_w * batch_size, image_h
        else:
            batch_w, batch_h = image_w, image_h * batch_size

        # Grid dimensions in cells: runs become columns (horizontal) or rows.
        if stack_direction == "horizontal":
            n_cols, n_rows = len(splits), max(splits)
        else:
            n_cols, n_rows = max(splits), len(splits)

        if horizontal_labels is not None:
            horizontal_labels = [str(lbl) for lbl in horizontal_labels]
            if stack_direction == "horizontal":
                if len(horizontal_labels) != len(splits):
                    raise Exception("Number of horizontal labels must match number of splits.")
            elif len(horizontal_labels) != max(splits):
                raise Exception("Number of horizontal labels must match maximum split size.")
        if vertical_labels is not None:
            vertical_labels = [str(lbl) for lbl in vertical_labels]
            if stack_direction == "horizontal":
                if len(vertical_labels) != max(splits):
                    raise Exception("Number of vertical labels must match maximum split size.")
            elif len(vertical_labels) != len(splits):
                raise Exception("Number of vertical labels must match number of splits.")

        label_size = self.LABEL_SIZE
        # The grid content is shifted right/down by whichever label bands exist.
        x_label_offset = label_size if vertical_labels is not None else 0
        y_label_offset = label_size if horizontal_labels is not None else 0

        full_image = Image.new(
            "RGB",
            (batch_w * n_cols + x_label_offset, batch_h * n_rows + y_label_offset),
        )
        if horizontal_labels is not None or vertical_labels is not None:
            self._draw_labels(
                full_image,
                horizontal_labels,
                vertical_labels,
                batch_w,
                batch_h,
                x_label_offset,
                y_label_offset,
                label_size,
            )

        batch_idx = 0  # index of the first batch of the current run
        for split_idx, split in enumerate(splits):
            for idx_in_split in range(split):
                cell = self._render_batch(
                    images[batch_idx + idx_in_split],
                    batch_w,
                    batch_h,
                    image_w,
                    image_h,
                    batch_stack_direction,
                )
                if stack_direction == "horizontal":
                    col, row = split_idx, idx_in_split
                else:
                    col, row = idx_in_split, split_idx
                full_image.paste(
                    cell,
                    (batch_w * col + x_label_offset, batch_h * row + y_label_offset),
                )
            batch_idx += split
        return (pil2tensor(full_image),)

    @staticmethod
    def _resolve_splits(splits: List[int], num_images: int) -> List[int]:
        """Return explicit run lengths, expanding a single run length ``n``.

        A single ``n`` yields ``num_images // n`` runs of ``n``, plus one final
        run holding the remainder, if any. Several values are validated and
        returned as a new list.
        """
        if len(splits) == 1:
            run = splits[0]
            if run < 1:
                # A 0 would otherwise divide by zero; negatives are invalid input.
                raise Exception("Split size must be at least 1.")
            full_runs, remainder = divmod(num_images, run)
            resolved = [run] * full_runs
            if remainder:
                resolved.append(remainder)
            return resolved
        if any(size < 0 for size in splits):
            raise Exception("Split sizes must not be negative.")
        if sum(splits) != num_images:
            raise Exception("Sum of splits must equal number of images.")
        return list(splits)

    @staticmethod
    def _draw_labels(
        canvas: "Image.Image",
        horizontal_labels: Optional[List[str]],
        vertical_labels: Optional[List[str]],
        cell_w: int,
        cell_h: int,
        x_label_offset: int,
        y_label_offset: int,
        label_size: int,
    ) -> None:
        """Paint white label bands with red, cell-centred text onto ``canvas``."""
        # Load the font and create the drawing context once for all labels.
        font = ImageFont.truetype(fm.findfont(fm.FontProperties()), label_size)
        draw = ImageDraw.Draw(canvas)
        for label_idx, label in enumerate(horizontal_labels or []):
            left = cell_w * label_idx + x_label_offset
            draw.rectangle((left, 0, left + cell_w, label_size), fill="#ffffff")
            # Anchor "ma": x is the text's horizontal middle, so the label is
            # centred over its column rather than starting at the midpoint.
            draw.text((left + cell_w / 2, 0), label, fill="red", font=font, anchor="ma")
        for label_idx, label in enumerate(vertical_labels or []):
            top = cell_h * label_idx + y_label_offset
            draw.rectangle((0, top, label_size, top + cell_h), fill="#ffffff")
            # Anchor "lm": y is the text's vertical middle, so the label is
            # centred on its row.
            draw.text((0, top + cell_h / 2), label, fill="red", font=font, anchor="lm")

    @staticmethod
    def _render_batch(
        batch,
        cell_w: int,
        cell_h: int,
        image_w: int,
        image_h: int,
        direction: str,
    ) -> "Image.Image":
        """Lay out one batch's images in a row ("horizontal") or a column."""
        cell = Image.new("RGB", (cell_w, cell_h))
        for img_idx, img in enumerate(batch):
            if direction == "horizontal":
                cell.paste(tensor2pil(img), (image_w * img_idx, 0))
            else:
                cell.paste(tensor2pil(img), (0, image_h * img_idx))
        return cell