saliacoel commited on
Commit
166476b
·
verified ·
1 Parent(s): 2dc80e2

Upload Get_Batch_Range_Start_To_End.py

Browse files
Files changed (1) hide show
  1. Get_Batch_Range_Start_To_End.py +86 -0
Get_Batch_Range_Start_To_End.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Save as: ComfyUI/custom_nodes/batch_slice_start_end.py
2
+ # Restart ComfyUI after saving.
3
+
4
+ import torch
5
+
6
+
7
+ class Get_Batch_Range_Start_To_End:
8
+ """
9
+ Inputs:
10
+ - start_id (INT)
11
+ - end_id (INT)
12
+ - images (IMAGE batch, typically torch.Tensor [B, H, W, C])
13
+
14
+ Outputs:
15
+ - sliced_images (IMAGE batch)
16
+ - status (STRING): "ok" or an error message
17
+ - count (INT): number of images in the *input* batch
18
+
19
+ Behavior:
20
+ - Returns images from start_id to end_id (inclusive).
21
+ - If invalid / impossible (out of range, start>end, empty batch, etc.),
22
+ returns the original input batch unchanged, plus an error message.
23
+ """
24
+
25
+ CATEGORY = "image/batch"
26
+ FUNCTION = "slice_batch"
27
+
28
+ RETURN_TYPES = ("IMAGE", "STRING", "INT")
29
+ RETURN_NAMES = ("images", "status", "count")
30
+
31
+ @classmethod
32
+ def INPUT_TYPES(cls):
33
+ return {
34
+ "required": {
35
+ "start_id": ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1}),
36
+ "end_id": ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1}),
37
+ "images": ("IMAGE",),
38
+ }
39
+ }
40
+
41
+ def slice_batch(self, start_id, end_id, images):
42
+ # Validate tensor
43
+ if not isinstance(images, torch.Tensor):
44
+ # Can't reliably "return original" if it's not a tensor, but try anyway.
45
+ return (images, "error: images is not a torch.Tensor", 0)
46
+
47
+ # Normalize to batched shape for safety
48
+ original = images
49
+ if images.dim() == 3:
50
+ images = images.unsqueeze(0) # [1, H, W, C]
51
+ elif images.dim() != 4:
52
+ # Return original unchanged
53
+ count = int(images.shape[0]) if images.dim() > 0 else 0
54
+ return (original, f"error: expected IMAGE with 3 or 4 dims, got {tuple(images.shape)}", count)
55
+
56
+ b = int(images.shape[0]) # input batch count
57
+
58
+ if b <= 0:
59
+ return (images, "error: empty batch (B=0)", 0)
60
+
61
+ # Validate indices (inclusive slicing)
62
+ if start_id > end_id:
63
+ return (images, f"error: start_id > end_id ({start_id} > {end_id})", b)
64
+
65
+ if start_id < 0 or end_id < 0:
66
+ return (images, f"error: negative index not allowed (start_id={start_id}, end_id={end_id})", b)
67
+
68
+ if start_id >= b or end_id >= b:
69
+ return (
70
+ images,
71
+ f"error: out of range (start_id={start_id}, end_id={end_id}, batch_size={b})",
72
+ b,
73
+ )
74
+
75
+ # Slice inclusive: [start_id, end_id]
76
+ sliced = images[start_id : end_id + 1].clone()
77
+ return (sliced, "ok", b)
78
+
79
+
80
+ NODE_CLASS_MAPPINGS = {
81
+ "Get_Batch_Range_Start_To_End": Get_Batch_Range_Start_To_End,
82
+ }
83
+
84
+ NODE_DISPLAY_NAME_MAPPINGS = {
85
+ "Get_Batch_Range_Start_To_End": "Get Batch from Batch (From Start ID to End ID)",
86
+ }