| |
|
| |
|
| |
|
| | import torch
|
| |
|
| |
|
class Custom_Batch_Output:
    """Split an IMAGE batch into two fixed frame selections.

    Input:
        images: IMAGE batch, typically a torch.Tensor shaped [B, H, W, C].
            A single 3-D tensor is promoted to a batch of one.

    Outputs:
        Batch_Up: [ID 7] + [IDs 9..25] + [IDs 27..31] + [IDs 33..36]
        Rife_x3:  [ID 4] + [ID 37] (2-image batch)

    Indexing is 0-based and ranges are inclusive (e.g., 9..25 includes both
    9 and 25).

    Safety behavior:
        If the input is not a torch.Tensor it is returned unchanged for BOTH
        outputs. If, after promoting a 3-D tensor to a batch, the tensor is
        not 4-D or the batch is too small (indices up to 37 are needed, so
        B >= 38), that tensor is returned for BOTH outputs.
    """

    CATEGORY = "image/batch"
    FUNCTION = "make_special_batches"

    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("Batch_Up", "Rife_x3")

    # Frame selections, built once at class creation instead of on every call.
    # range() stops are exclusive, hence 26 / 32 / 37 for the inclusive ends.
    _BATCH_UP_INDICES = (7, *range(9, 26), *range(27, 32), *range(33, 37))
    _RIFE_X3_INDICES = (4, 37)

    # Smallest batch that contains every referenced index (37 + 1 == 38).
    # Derived from the index tuples so the guard cannot drift out of sync.
    _MIN_BATCH_SIZE = max(_BATCH_UP_INDICES + _RIFE_X3_INDICES) + 1

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: an IMAGE named ``images``."""
        return {"required": {"images": ("IMAGE",)}}

    @staticmethod
    def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
        """Return ``images`` with a leading batch axis added if it is 3-D.

        Tensors of any other rank are returned unchanged.
        """
        if images.dim() == 3:
            return images.unsqueeze(0)
        return images

    def make_special_batches(self, images):
        """Build the ``Batch_Up`` and ``Rife_x3`` batches from ``images``.

        Args:
            images: Expected to be a torch.Tensor IMAGE batch; anything else
                is passed through untouched.

        Returns:
            A 2-tuple ``(batch_up, rife_x3)``. When the input cannot be split
            (not a tensor, not 4-D after normalization, or fewer than 38
            frames) both elements are the same pass-through object.
        """
        if not isinstance(images, torch.Tensor):
            return (images, images)

        images = self._normalize_to_batch(images)

        # Pass through anything that is not a batch large enough to index.
        if images.dim() != 4 or images.shape[0] < self._MIN_BATCH_SIZE:
            return (images, images)

        # Index tensors must live on the same device as the data they index.
        device = images.device
        idx_up = torch.tensor(
            self._BATCH_UP_INDICES, dtype=torch.long, device=device
        )
        idx_rife = torch.tensor(
            self._RIFE_X3_INDICES, dtype=torch.long, device=device
        )

        # index_select returns a new tensor that does not share storage with
        # its input, so an extra .clone() would only copy the frames twice.
        batch_up = torch.index_select(images, 0, idx_up)
        rife_x3 = torch.index_select(images, 0, idx_rife)

        return (batch_up, rife_x3)
|
| |
|
| |
|
# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application when it registers
# custom nodes — confirm against the loader.
NODE_CLASS_MAPPINGS = {"Custom_Batch_Output": Custom_Batch_Output}

# Node identifier -> display label; each label is the identifier itself.
NODE_DISPLAY_NAME_MAPPINGS = {node_id: node_id for node_id in NODE_CLASS_MAPPINGS}
|
| |
|