File size: 5,408 Bytes
f2ad0fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch

def bresenham_line(x0: int, y0: int, x1: int, y1: int):
    """Trace the integer grid cells covered by the segment (x0, y0)->(x1, y1).

    Classic all-integer Bresenham in the combined-error (e_xy) form, which
    handles every octant without special cases. Both endpoints are included.

    Returns:
        Two equal-length Python lists (xs, ys) of pixel coordinates, in
        order from the start point to the end point.
    """
    col_step = 1 if x0 < x1 else -1
    row_step = 1 if y0 < y1 else -1
    run = abs(x1 - x0)       # horizontal extent (>= 0)
    rise = -abs(y1 - y0)     # vertical extent, negated by convention
    error = run + rise       # accumulated error term e_xy

    cols, rows = [x0], [y0]
    cx, cy = x0, y0
    while not (cx == x1 and cy == y1):
        doubled = 2 * error
        if doubled >= rise:  # horizontal error dominates: step in x
            error += rise
            cx += col_step
        if doubled <= run:   # vertical error dominates: step in y
            error += run
            cy += row_step
        cols.append(cx)
        rows.append(cy)

    return cols, rows


class BatchRaycast_2D:
    """ComfyUI node: scan a batch of images and return the first one in which
    every pixel along the START->END line segment counts as "white" under the
    selected detection mode and threshold.

    Outputs the matching image (still batched, shape [1,H,W,C]) plus its
    batch index; on no match (or an out-of-bounds line) a fallback image is
    returned together with index -1.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "images": ("IMAGE",),

            # "White" detection behavior:
            # - max_channel: max(R,G,B) >= threshold  (VERY tolerant; rgb(0,1,0) passes)
            # - all_channels: min(R,G,B) >= threshold (strict white)
            # - luminance: dot([0.2126, 0.7152, 0.0722], RGB) >= threshold
            # - green_only: G >= threshold
            "white_mode": (["max_channel", "all_channels", "luminance", "green_only"], {"default": "max_channel"}),

            # For ComfyUI IMAGE tensors this is typically 0..1 float.
            # Example: threshold=0.98 means "near 1.0"
            "threshold": ("FLOAT", {"default": 0.98, "min": 0.0, "max": 1.0, "step": 0.005}),

            # What to do if no image matches:
            "fallback": (["return_first", "return_last"], {"default": "return_last"}),
        }
        optional = {
            # Keep your requested defaults, but allow override if needed.
            "start_x": ("INT", {"default": 0, "min": -999999, "max": 999999}),
            "start_y": ("INT", {"default": 386, "min": -999999, "max": 999999}),
            "end_x": ("INT", {"default": 330, "min": -999999, "max": 999999}),
            "end_y": ("INT", {"default": 385, "min": -999999, "max": 999999}),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("IMAGE", "INT")
    RETURN_NAMES = ("image", "index")
    FUNCTION = "pick"
    CATEGORY = "image/filter"

    def pick(
        self,
        images,
        white_mode="max_channel",
        threshold=0.98,
        fallback="return_last",
        start_x=0,
        start_y=386,
        end_x=330,
        end_y=385,
    ):
        # ComfyUI IMAGE tensors are [B, H, W, C] floats, typically in 0..1.
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor (ComfyUI IMAGE type).")
        if images.ndim != 4 or images.shape[-1] < 3:
            raise ValueError(f"Expected images shape [B,H,W,C>=3], got {tuple(images.shape)}")

        batch, height, width, _channels = images.shape
        dev = images.device

        def _fallback_result():
            # Shared "no match" exit: chosen image per `fallback`, index -1.
            i = 0 if fallback == "return_first" else max(batch - 1, 0)
            return (images[i:i + 1], int(-1))

        # Rasterize the segment into integer pixel coordinates.
        px, py = bresenham_line(int(start_x), int(start_y), int(end_x), int(end_y))
        col_idx = torch.tensor(px, dtype=torch.long, device=dev)
        row_idx = torch.tensor(py, dtype=torch.long, device=dev)

        # A line that leaves the image is treated as "no match" for safety.
        inside = (
            col_idx.min().item() >= 0
            and row_idx.min().item() >= 0
            and col_idx.max().item() < width
            and row_idx.max().item() < height
        )
        if not inside:
            return _fallback_result()

        # Gather the line's pixels from every image at once: [B, N, 3].
        samples = images[:, row_idx, col_idx, :3]

        if white_mode == "max_channel":
            # VERY tolerant: any channel close to 1 counts as "white".
            hits = samples.max(dim=-1).values >= threshold
        elif white_mode == "all_channels":
            # Strict white: every channel must be close to 1.
            hits = samples.min(dim=-1).values >= threshold
        elif white_mode == "green_only":
            # Only the green channel must be high.
            hits = samples[..., 1] >= threshold
        elif white_mode == "luminance":
            # Perceived brightness (Rec.709-ish weights).
            coeffs = torch.tensor([0.2126, 0.7152, 0.0722], device=dev, dtype=samples.dtype)
            hits = (samples * coeffs).sum(dim=-1) >= threshold
        else:
            raise ValueError(f"Unknown white_mode: {white_mode}")

        # An image qualifies only if ALL sampled pixels are white: [B] bools.
        qualifies = hits.all(dim=1)

        matches = torch.nonzero(qualifies, as_tuple=False).flatten()
        if matches.numel() == 0:
            return _fallback_result()

        winner = int(matches[0].item())
        return (images[winner:winner + 1], winner)


# ComfyUI registration: maps the node's internal key to its class.
NODE_CLASS_MAPPINGS = {
    "BatchRaycast_2D": BatchRaycast_2D
}

# Human-readable label shown in the ComfyUI node picker for the same key.
NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchRaycast_2D": "BatchRaycast_2D"
}