File size: 3,004 Bytes
836f12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch


class SlideCropBatch40:
    """Create a 40-frame horizontal sliding crop batch from a single 3025x1024 image.

    Output:
        IMAGE tensor with shape [40, 1024, 1536, C].

    Notes:
        - The first crop covers x = 0..1535 inclusive.
        - For a 1536-wide crop taken from a 3025-wide image, the last valid crop
          must start at x = 1489 and end at x = 3024 inclusive.
        - The exact per-frame shift over 40 frames cannot be both constant and an
          integer, because 1489 / 39 is not an integer. This node therefore uses
          the nearest integer positions that are evenly spaced from 0 to 1489
          inclusive.
    """

    CATEGORY = "image/animation"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "make_batch"

    FRAME_COUNT = 40
    INPUT_WIDTH = 3025
    INPUT_HEIGHT = 1024
    CROP_WIDTH = 1536
    CROP_HEIGHT = 1024

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required inputs (NOTE(review): presumably the
        ComfyUI input-schema hook — confirm against the host framework)."""
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    @classmethod
    def _positions(cls):
        """Return FRAME_COUNT integer left-edge x offsets, evenly spaced from 0
        to INPUT_WIDTH - CROP_WIDTH inclusive.

        Uses integer-only nearest rounding of ``i * max_shift / intervals`` so
        both endpoints are hit exactly and no float error can creep in.
        """
        intervals = cls.FRAME_COUNT - 1
        max_shift = cls.INPUT_WIDTH - cls.CROP_WIDTH  # 1489 for the default sizes
        if intervals <= 0:
            # Degenerate single-frame configuration: only the leftmost crop.
            # (Guards the integer division below against FRAME_COUNT == 1.)
            return [0] * cls.FRAME_COUNT
        return [((i * max_shift) + intervals // 2) // intervals for i in range(cls.FRAME_COUNT)]

    def make_batch(self, image: torch.Tensor):
        """Slice ``image`` into FRAME_COUNT horizontally sliding crops.

        Args:
            image: IMAGE tensor shaped [1, INPUT_HEIGHT, INPUT_WIDTH, C].

        Returns:
            1-tuple holding a tensor shaped [FRAME_COUNT, CROP_HEIGHT, CROP_WIDTH, C].

        Raises:
            TypeError: if ``image`` is not a torch.Tensor.
            ValueError: if the rank, batch size, resolution, or channel count is wrong.
        """
        if not isinstance(image, torch.Tensor):
            raise TypeError("Expected IMAGE input as a torch.Tensor.")

        if image.ndim != 4:
            raise ValueError(f"Expected IMAGE tensor with shape [B,H,W,C], got shape {tuple(image.shape)}")

        batch, height, width, channels = image.shape

        if batch != 1:
            raise ValueError(
                f"This node expects exactly 1 input image (batch size 1), but got batch size {batch}."
            )

        if height != self.INPUT_HEIGHT or width != self.INPUT_WIDTH:
            raise ValueError(
                f"Expected input resolution {self.INPUT_WIDTH}x{self.INPUT_HEIGHT}, "
                f"but got {width}x{height}."
            )

        if channels < 1:
            raise ValueError(f"Expected at least 1 channel, got {channels}.")

        single = image[0]  # [H, W, C]
        # Positions are guaranteed in 0..(INPUT_WIDTH - CROP_WIDTH) by _positions()
        # and width was validated above, so every narrow() is in bounds.
        # narrow() yields zero-copy views; stack() materializes the final batch.
        crops = [single.narrow(1, x, self.CROP_WIDTH) for x in self._positions()]
        output = torch.stack(crops, dim=0)  # [FRAME_COUNT, CROP_HEIGHT, CROP_WIDTH, C]
        return (output,)


# Registration tables consumed at import time by the host framework
# (NOTE(review): the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS names
# follow the ComfyUI custom-node convention — confirm against the loader):
# node id -> implementing class, and node id -> human-readable display name.
NODE_CLASS_MAPPINGS = {
    "SlideCropBatch40": SlideCropBatch40,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SlideCropBatch40": "Slide Crop Batch 40",
}