import torch
def bresenham_line(x0: int, y0: int, x1: int, y1: int):
    """
    Rasterize the segment (x0, y0) -> (x1, y1) with the integer Bresenham
    algorithm.

    Returns:
        (xs, ys): two equal-length Python lists of int coordinates covering
        the segment, both endpoints included (always at least one point).
    """
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    span_x = abs(x1 - x0)
    span_y = -abs(y1 - y0)      # negated by convention so one err term serves both axes
    error = span_x + span_y     # combined error value e_xy
    col, row = x0, y0
    cols: list = []
    rows: list = []
    while True:
        cols.append(col)
        rows.append(row)
        if (col, row) == (x1, y1):
            return cols, rows
        doubled = error * 2
        if doubled >= span_y:   # horizontal step
            error += span_y
            col += step_x
        if doubled <= span_x:   # vertical step
            error += span_x
            row += step_y
class BatchRaycast_2D:
    """
    Return the first image in the batch where the START->END line is
    completely "white enough" according to the chosen mode/threshold.

    Outputs a 1-image batch slice plus the matching batch index, or a
    fallback image with index -1 when no image matches (including the
    case where the line leaves the image bounds).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # "White" detection behavior:
                # - max_channel: max(R,G,B) >= threshold (VERY tolerant; rgb(0,1,0) passes)
                # - all_channels: min(R,G,B) >= threshold (strict white)
                # - luminance: dot([0.2126, 0.7152, 0.0722], RGB) >= threshold
                # - green_only: G >= threshold
                "white_mode": (["max_channel", "all_channels", "luminance", "green_only"], {"default": "max_channel"}),
                # For ComfyUI IMAGE tensors this is typically 0..1 float.
                # Example: threshold=0.98 means "near 1.0"
                "threshold": ("FLOAT", {"default": 0.98, "min": 0.0, "max": 1.0, "step": 0.005}),
                # What to do if no image matches:
                "fallback": (["return_first", "return_last"], {"default": "return_last"}),
            },
            "optional": {
                # Line endpoint overrides; the defaults match the original use case.
                "start_x": ("INT", {"default": 0, "min": -999999, "max": 999999}),
                "start_y": ("INT", {"default": 386, "min": -999999, "max": 999999}),
                "end_x": ("INT", {"default": 330, "min": -999999, "max": 999999}),
                "end_y": ("INT", {"default": 385, "min": -999999, "max": 999999}),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT")
    RETURN_NAMES = ("image", "index")
    FUNCTION = "pick"
    CATEGORY = "image/filter"

    @staticmethod
    def _fallback_index(fallback: str, batch_size: int) -> int:
        """Batch index to emit when no image matched the line test."""
        return 0 if fallback == "return_first" else max(batch_size - 1, 0)

    def pick(
        self,
        images,
        white_mode="max_channel",
        threshold=0.98,
        fallback="return_last",
        start_x=0,
        start_y=386,
        end_x=330,
        end_y=385,
    ):
        """
        Scan the batch and return (image_slice, index).

        Args:
            images: torch tensor [B, H, W, C], float in 0..1 (typical ComfyUI IMAGE).
            white_mode: one of "max_channel", "all_channels", "luminance", "green_only".
            threshold: per-pixel whiteness cutoff in the 0..1 range.
            fallback: "return_first" or "return_last" image when nothing matches.
            start_x/start_y/end_x/end_y: integer line endpoints in pixel coords.

        Returns:
            (images[idx:idx+1], idx) for the first matching image, or
            (fallback image slice, -1) when no image matches or the line
            is out of bounds.

        Raises:
            TypeError: images is not a torch.Tensor.
            ValueError: images is not [B,H,W,C>=3], or white_mode is unknown.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor (ComfyUI IMAGE type).")
        if images.ndim != 4 or images.shape[-1] < 3:
            raise ValueError(f"Expected images shape [B,H,W,C>=3], got {tuple(images.shape)}")
        B, H, W, C = images.shape

        # Integer pixel coordinates along the line (always >= 1 point).
        xs_list, ys_list = bresenham_line(int(start_x), int(start_y), int(end_x), int(end_y))

        # Bounds-check on the Python lists: same result as the tensor
        # .min()/.max().item() calls, but avoids four device->host syncs.
        # If the line goes out of bounds, we treat it as "no match" for safety.
        if min(xs_list) < 0 or min(ys_list) < 0 or max(xs_list) >= W or max(ys_list) >= H:
            idx = self._fallback_index(fallback, B)
            return (images[idx:idx + 1], -1)

        # Convert to index tensors on the same device as the images.
        device = images.device
        xs = torch.tensor(xs_list, device=device, dtype=torch.long)
        ys = torch.tensor(ys_list, device=device, dtype=torch.long)

        # Sample pixels along the line for every image in the batch: shape [B, N, 3]
        pixels = images[:, ys, xs, :3]

        if white_mode == "max_channel":
            # VERY tolerant: any channel close to 1 counts as "white"
            white_mask = pixels.max(dim=-1).values >= threshold
        elif white_mode == "all_channels":
            # Strict white: all channels must be close to 1
            white_mask = pixels.min(dim=-1).values >= threshold
        elif white_mode == "green_only":
            # Only green channel must be high
            white_mask = pixels[..., 1] >= threshold
        elif white_mode == "luminance":
            # Perceived brightness (Rec.709-ish)
            weights = torch.tensor([0.2126, 0.7152, 0.0722], device=device, dtype=pixels.dtype)
            lum = (pixels * weights).sum(dim=-1)
            white_mask = lum >= threshold
        else:
            raise ValueError(f"Unknown white_mode: {white_mode}")

        # Line is "completely white" only if ALL pixels along it are white.
        line_is_white = white_mask.all(dim=1)  # shape [B]
        found = torch.nonzero(line_is_white, as_tuple=False).flatten()
        if found.numel() > 0:
            idx = int(found[0].item())
            return (images[idx:idx + 1], idx)

        # No match: fallback image, sentinel index -1.
        idx = self._fallback_index(fallback, B)
        return (images[idx:idx + 1], -1)
# Registration tables — presumably consumed by ComfyUI's node loader
# (the file targets ComfyUI per the comments above): internal node id ->
# implementing class, and internal node id -> label shown in the UI.
NODE_CLASS_MAPPINGS = {
    "BatchRaycast_2D": BatchRaycast_2D
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchRaycast_2D": "BatchRaycast_2D"
}