| |
|
| |
|
| |
|
| | import torch
|
| |
|
| |
|
class BatchRemoveFirstLast:
    """
    Trim one image off each end of an IMAGE batch.

    The result is ``images[1:-1]``: the image at index 0 and the image at
    index B-1 are dropped, and everything in between is kept in order.

    Pass-through cases (the input is returned instead of being trimmed):
      - Non-tensor inputs are returned untouched.
      - Tensors whose rank is neither 3 nor 4 are returned untouched.
      - Batches with fewer than 3 images are returned as-is, because
        trimming both ends would leave nothing behind.

    A 3-D tensor is treated as a single image (``[H, W, C]``) and gets a
    leading batch axis first. Because that batch has only one image, it
    comes back as a 1-image, 4-D batch.
    """

    CATEGORY = "image/batch"
    FUNCTION = "remove_first_last"

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required IMAGE input named ``images``."""
        images_spec = ("IMAGE",)
        return {"required": {"images": images_spec}}

    def remove_first_last(self, images):
        """
        Return a 1-tuple holding the batch without its first and last images.

        The trimmed result is cloned, so it does not share storage with the
        caller's tensor. Every pass-through case returns the (possibly
        batch-expanded) input inside a 1-tuple.
        """
        # Anything that is not a tensor is handed back untouched.
        if not isinstance(images, torch.Tensor):
            return (images,)

        rank = images.dim()
        if rank not in (3, 4):
            # Unsupported rank: leave it for downstream nodes to deal with.
            return (images,)
        if rank == 3:
            # Single image -> add a leading batch axis (same as unsqueeze(0)).
            images = images[None]

        batch_size = int(images.shape[0])
        if batch_size < 3:
            # Dropping both ends would empty the batch, so keep it whole.
            return (images,)

        trimmed = images[1:-1].clone()
        return (trimmed,)
|
| |
|
| |
|
# Node registry: maps the node's type key to its implementing class.
# Presumably read by the host's (ComfyUI-style) custom-node loader — confirm.
NODE_CLASS_MAPPINGS = dict(
    BatchRemoveFirstLast=BatchRemoveFirstLast,
)

# Human-readable label for each registered node key (same key as above).
NODE_DISPLAY_NAME_MAPPINGS = dict(
    BatchRemoveFirstLast="Batch Remove First + Last",
)
|
| |
|