saliacoel commited on
Commit
f2ad0fa
·
verified ·
1 Parent(s): 3fb7a98

Upload BatchRaycast_2d.py

Browse files
Files changed (1) hide show
  1. BatchRaycast_2d.py +151 -0
BatchRaycast_2d.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def bresenham_line(x0: int, y0: int, x1: int, y1: int):
4
+ """
5
+ Integer Bresenham line algorithm.
6
+ Returns two Python lists: xs, ys (same length).
7
+ """
8
+ xs = []
9
+ ys = []
10
+
11
+ dx = abs(x1 - x0)
12
+ sx = 1 if x0 < x1 else -1
13
+ dy = -abs(y1 - y0)
14
+ sy = 1 if y0 < y1 else -1
15
+ err = dx + dy # error value e_xy
16
+
17
+ x, y = x0, y0
18
+ while True:
19
+ xs.append(x)
20
+ ys.append(y)
21
+ if x == x1 and y == y1:
22
+ break
23
+ e2 = 2 * err
24
+ if e2 >= dy:
25
+ err += dy
26
+ x += sx
27
+ if e2 <= dx:
28
+ err += dx
29
+ y += sy
30
+
31
+ return xs, ys
32
+
33
+
34
+ class BatchRaycast_2D:
35
+ """
36
+ Returns the first image in the batch where the START->END line
37
+ is completely "white enough" according to the chosen mode/threshold.
38
+ """
39
+
40
+ @classmethod
41
+ def INPUT_TYPES(cls):
42
+ return {
43
+ "required": {
44
+ "images": ("IMAGE",),
45
+
46
+ # "White" detection behavior:
47
+ # - max_channel: max(R,G,B) >= threshold (VERY tolerant; rgb(0,1,0) passes)
48
+ # - all_channels: min(R,G,B) >= threshold (strict white)
49
+ # - luminance: dot([0.2126, 0.7152, 0.0722], RGB) >= threshold
50
+ # - green_only: G >= threshold
51
+ "white_mode": (["max_channel", "all_channels", "luminance", "green_only"], {"default": "max_channel"}),
52
+
53
+ # For ComfyUI IMAGE tensors this is typically 0..1 float.
54
+ # Example: threshold=0.98 means "near 1.0"
55
+ "threshold": ("FLOAT", {"default": 0.98, "min": 0.0, "max": 1.0, "step": 0.005}),
56
+
57
+ # What to do if no image matches:
58
+ "fallback": (["return_first", "return_last"], {"default": "return_last"}),
59
+ },
60
+ "optional": {
61
+ # Keep your requested defaults, but allow override if needed.
62
+ "start_x": ("INT", {"default": 0, "min": -999999, "max": 999999}),
63
+ "start_y": ("INT", {"default": 386, "min": -999999, "max": 999999}),
64
+ "end_x": ("INT", {"default": 330, "min": -999999, "max": 999999}),
65
+ "end_y": ("INT", {"default": 385, "min": -999999, "max": 999999}),
66
+ }
67
+ }
68
+
69
+ RETURN_TYPES = ("IMAGE", "INT")
70
+ RETURN_NAMES = ("image", "index")
71
+ FUNCTION = "pick"
72
+ CATEGORY = "image/filter"
73
+
74
+ def pick(
75
+ self,
76
+ images,
77
+ white_mode="max_channel",
78
+ threshold=0.98,
79
+ fallback="return_last",
80
+ start_x=0,
81
+ start_y=386,
82
+ end_x=330,
83
+ end_y=385,
84
+ ):
85
+ # images: torch tensor [B,H,W,C], float in 0..1 (typical in ComfyUI)
86
+ if not isinstance(images, torch.Tensor):
87
+ raise TypeError("images must be a torch.Tensor (ComfyUI IMAGE type).")
88
+
89
+ if images.ndim != 4 or images.shape[-1] < 3:
90
+ raise ValueError(f"Expected images shape [B,H,W,C>=3], got {tuple(images.shape)}")
91
+
92
+ B, H, W, C = images.shape
93
+
94
+ # Build the integer pixel coordinates along the line
95
+ xs_list, ys_list = bresenham_line(int(start_x), int(start_y), int(end_x), int(end_y))
96
+
97
+ # Convert to tensors on same device
98
+ device = images.device
99
+ xs = torch.tensor(xs_list, device=device, dtype=torch.long)
100
+ ys = torch.tensor(ys_list, device=device, dtype=torch.long)
101
+
102
+ # If the line goes out of bounds, we treat it as "no match" for safety.
103
+ if xs.min().item() < 0 or ys.min().item() < 0 or xs.max().item() >= W or ys.max().item() >= H:
104
+ # fallback output
105
+ idx = 0 if fallback == "return_first" else max(B - 1, 0)
106
+ return (images[idx:idx+1], int(-1))
107
+
108
+ # Sample pixels along the line for every image in the batch: shape [B, N, 3]
109
+ pixels = images[:, ys, xs, :3]
110
+
111
+ if white_mode == "max_channel":
112
+ # VERY tolerant: any channel close to 1 counts as "white"
113
+ white_mask = pixels.max(dim=-1).values >= threshold
114
+
115
+ elif white_mode == "all_channels":
116
+ # Strict white: all channels must be close to 1
117
+ white_mask = pixels.min(dim=-1).values >= threshold
118
+
119
+ elif white_mode == "green_only":
120
+ # Only green channel must be high
121
+ white_mask = pixels[..., 1] >= threshold
122
+
123
+ elif white_mode == "luminance":
124
+ # Perceived brightness (Rec.709-ish)
125
+ weights = torch.tensor([0.2126, 0.7152, 0.0722], device=device, dtype=pixels.dtype)
126
+ lum = (pixels * weights).sum(dim=-1)
127
+ white_mask = lum >= threshold
128
+
129
+ else:
130
+ raise ValueError(f"Unknown white_mode: {white_mode}")
131
+
132
+ # Line is "completely white" if ALL pixels along line are white
133
+ line_is_white = white_mask.all(dim=1) # shape [B]
134
+
135
+ found = torch.nonzero(line_is_white, as_tuple=False).flatten()
136
+ if found.numel() > 0:
137
+ idx = int(found[0].item())
138
+ return (images[idx:idx+1], idx)
139
+
140
+ # No match: fallback
141
+ idx = 0 if fallback == "return_first" else max(B - 1, 0)
142
+ return (images[idx:idx+1], int(-1))
143
+
144
+
145
+ NODE_CLASS_MAPPINGS = {
146
+ "BatchRaycast_2D": BatchRaycast_2D
147
+ }
148
+
149
+ NODE_DISPLAY_NAME_MAPPINGS = {
150
+ "BatchRaycast_2D": "BatchRaycast_2D"
151
+ }