# MyCustomNodes / BatchRaycast_2d.py
# Uploaded by saliacoel (revision f2ad0fa, verified)
import torch
def bresenham_line(x0: int, y0: int, x1: int, y1: int):
    """
    Walk the integer grid from (x0, y0) to (x1, y1) with Bresenham's
    algorithm and collect every cell visited, endpoints included.

    Returns:
        (cols, rows): two equal-length Python lists of ints — the x and y
        coordinates of each pixel along the line, in traversal order.
    """
    cols: list = []
    rows: list = []
    delta_x = abs(x1 - x0)
    delta_y = -abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    # Combined error term (classic e_xy formulation).
    error = delta_x + delta_y
    col, row = x0, y0
    while True:
        cols.append(col)
        rows.append(row)
        if col == x1 and row == y1:
            return cols, rows
        doubled = error * 2
        if doubled >= delta_y:
            # Horizontal step dominates.
            error += delta_y
            col += step_x
        if doubled <= delta_x:
            # Vertical step dominates (both may fire on a diagonal move).
            error += delta_x
            row += step_y
class BatchRaycast_2D:
    """
    ComfyUI node: cast a 2-D ray (START -> END line segment) across every
    image in a batch and return the first image whose entire segment is
    "white enough" according to the chosen mode/threshold.

    Outputs:
        image: a [1, H, W, C] slice of the batch — the first match, or the
               fallback image when nothing matched.
        index: the matching batch index, or -1 when no image matched
               (including when the segment leaves the image bounds).
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # "White" detection behavior:
                # - max_channel: max(R,G,B) >= threshold (VERY tolerant; rgb(0,1,0) passes)
                # - all_channels: min(R,G,B) >= threshold (strict white)
                # - luminance: dot([0.2126, 0.7152, 0.0722], RGB) >= threshold
                # - green_only: G >= threshold
                "white_mode": (["max_channel", "all_channels", "luminance", "green_only"], {"default": "max_channel"}),
                # For ComfyUI IMAGE tensors this is typically 0..1 float.
                # Example: threshold=0.98 means "near 1.0"
                "threshold": ("FLOAT", {"default": 0.98, "min": 0.0, "max": 1.0, "step": 0.005}),
                # What to do if no image matches:
                "fallback": (["return_first", "return_last"], {"default": "return_last"}),
            },
            "optional": {
                # Keep your requested defaults, but allow override if needed.
                "start_x": ("INT", {"default": 0, "min": -999999, "max": 999999}),
                "start_y": ("INT", {"default": 386, "min": -999999, "max": 999999}),
                "end_x": ("INT", {"default": 330, "min": -999999, "max": 999999}),
                "end_y": ("INT", {"default": 385, "min": -999999, "max": 999999}),
            }
        }
    RETURN_TYPES = ("IMAGE", "INT")
    RETURN_NAMES = ("image", "index")
    FUNCTION = "pick"
    CATEGORY = "image/filter"
    def pick(
        self,
        images,
        white_mode="max_channel",
        threshold=0.98,
        fallback="return_last",
        start_x=0,
        start_y=386,
        end_x=330,
        end_y=385,
    ):
        """
        Scan the batch and pick the first image whose START->END line is
        entirely "white" per `white_mode`/`threshold`.

        Args:
            images: torch.Tensor of shape [B, H, W, C>=3] (ComfyUI IMAGE,
                typically float in 0..1 — TODO confirm range at call site).
            white_mode: one of "max_channel", "all_channels", "luminance",
                "green_only" (see INPUT_TYPES for semantics).
            threshold: per-pixel whiteness cutoff, compared with >=.
            fallback: which image to return on no match / out-of-bounds ray.
            start_x, start_y, end_x, end_y: integer segment endpoints in
                pixel coordinates (x = column, y = row).

        Returns:
            (image, index) — image is a [1,H,W,C] batch slice; index is the
            matching batch position or -1 when nothing matched.

        Raises:
            TypeError: images is not a torch.Tensor.
            ValueError: images has the wrong rank/channels, or white_mode
                is not one of the supported strings.
        """
        # images: torch tensor [B,H,W,C], float in 0..1 (typical in ComfyUI)
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor (ComfyUI IMAGE type).")
        if images.ndim != 4 or images.shape[-1] < 3:
            raise ValueError(f"Expected images shape [B,H,W,C>=3], got {tuple(images.shape)}")
        B, H, W, C = images.shape
        # Integer pixel coordinates along the segment (always >= 1 point,
        # since Bresenham emits the start point even for a degenerate line).
        xs_list, ys_list = bresenham_line(int(start_x), int(start_y), int(end_x), int(end_y))
        # Bounds check on the plain Python lists BEFORE building any tensor:
        # cheaper, and avoids the host-device sync that .item() on a CUDA
        # tensor would force (the original did four such syncs here).
        if (min(xs_list) < 0 or min(ys_list) < 0
                or max(xs_list) >= W or max(ys_list) >= H):
            # Line leaves the image: treated as "no match" for safety.
            idx = 0 if fallback == "return_first" else max(B - 1, 0)
            return (images[idx:idx + 1], -1)
        # Convert to index tensors on the same device as the batch.
        device = images.device
        xs = torch.tensor(xs_list, device=device, dtype=torch.long)
        ys = torch.tensor(ys_list, device=device, dtype=torch.long)
        # Sample pixels along the line for every image at once: [B, N, 3]
        pixels = images[:, ys, xs, :3]
        if white_mode == "max_channel":
            # VERY tolerant: any channel close to 1 counts as "white"
            white_mask = pixels.max(dim=-1).values >= threshold
        elif white_mode == "all_channels":
            # Strict white: all channels must be close to 1
            white_mask = pixels.min(dim=-1).values >= threshold
        elif white_mode == "green_only":
            # Only green channel must be high
            white_mask = pixels[..., 1] >= threshold
        elif white_mode == "luminance":
            # Perceived brightness (Rec.709-ish weights)
            weights = torch.tensor([0.2126, 0.7152, 0.0722], device=device, dtype=pixels.dtype)
            lum = (pixels * weights).sum(dim=-1)
            white_mask = lum >= threshold
        else:
            raise ValueError(f"Unknown white_mode: {white_mode}")
        # The ray "hits" an image only if EVERY sampled pixel passes.
        line_is_white = white_mask.all(dim=1)  # shape [B]
        found = torch.nonzero(line_is_white, as_tuple=False).flatten()
        if found.numel() > 0:
            idx = int(found[0].item())
            return (images[idx:idx + 1], idx)
        # No match: return the fallback image; -1 signals the miss.
        idx = 0 if fallback == "return_first" else max(B - 1, 0)
        return (images[idx:idx + 1], -1)
# ComfyUI registration tables: node key -> implementing class, and
# node key -> label shown in the node-picker UI.
NODE_CLASS_MAPPINGS = {"BatchRaycast_2D": BatchRaycast_2D}
NODE_DISPLAY_NAME_MAPPINGS = {"BatchRaycast_2D": "BatchRaycast_2D"}