"""Unit tests for the ``ImageStitch`` node from ``comfy_extras.nodes_images``."""
from unittest.mock import MagicMock, patch

import torch

# Stand-in for the ``nodes`` module. Only MAX_RESOLUTION is given a concrete
# value; every other attribute is an auto-created MagicMock.
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384

# Stand-in for the ``server`` module.
mock_server = MagicMock()

# Substitute the mocks in sys.modules only for the duration of the import,
# presumably so that importing the node module does not pull in the real
# application modules. patch.dict restores sys.modules when the block exits.
with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}):
    from comfy_extras.nodes_images import ImageStitch
| |
|
| |
|
class TestImageStitch:
    """Shape and value checks for ``ImageStitch.stitch``.

    Every test calls ``stitch`` positionally as
    ``stitch(image1, direction, <size-matching flag>, spacing_width,
    spacing_color, image2)`` and inspects ``result[0]``. Images are built by
    ``create_test_image`` with layout (batch, height, width, channels).
    """

    def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
        """Return a random tensor of shape (batch_size, height, width, channels)."""
        return torch.rand(batch_size, height, width, channels)

    def test_no_image2_passthrough(self):
        """With ``image2=None`` the result is a 1-tuple holding image1 unchanged."""
        node = ImageStitch()
        image1 = self.create_test_image()

        result = node.stitch(image1, "right", True, 0, "white", image2=None)

        assert len(result) == 1
        assert torch.equal(result[0], image1)

    def test_basic_horizontal_stitch_right(self):
        """Stitching to the right adds the widths: 32 + 24 = 56."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)

    def test_basic_horizontal_stitch_left(self):
        """Stitching to the left adds the widths: 32 + 24 = 56."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)

        result = node.stitch(image1, "left", False, 0, "white", image2)

        assert result[0].shape == (1, 32, 56, 3)

    def test_basic_vertical_stitch_down(self):
        """Stitching downward adds the heights: 32 + 24 = 56."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)

    def test_basic_vertical_stitch_up(self):
        """Stitching upward adds the heights: 32 + 24 = 56."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)

        result = node.stitch(image1, "up", False, 0, "white", image2)

        assert result[0].shape == (1, 56, 32, 3)

    def test_size_matching_horizontal(self):
        """Size matching for horizontal concatenation."""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "right", True, 0, "white", image2)

        # The 32x32 image2 is expected to contribute 64 columns once its
        # height is matched to image1's 64 rows.
        expected_width = 64 + 64
        assert result[0].shape == (1, 64, expected_width, 3)

    def test_size_matching_vertical(self):
        """Size matching for vertical concatenation."""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=64)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "down", True, 0, "white", image2)

        # The 32x32 image2 is expected to contribute 64 rows once its width
        # is matched to image1's 64 columns.
        expected_height = 64 + 64
        assert result[0].shape == (1, expected_height, 64, 3)

    def test_padding_for_mismatched_heights_horizontal(self):
        """Mismatched heights without size matching: output takes the taller height."""
        node = ImageStitch()
        image1 = self.create_test_image(height=64, width=32)
        image2 = self.create_test_image(height=48, width=24)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        # Height is max(64, 48) = 64; width is 32 + 24 = 56.
        assert result[0].shape == (1, 64, 56, 3)

    def test_padding_for_mismatched_widths_vertical(self):
        """Mismatched widths without size matching: output takes the wider width."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=64)
        image2 = self.create_test_image(height=24, width=48)

        result = node.stitch(image1, "down", False, 0, "white", image2)

        # Height is 32 + 24 = 56; width is max(64, 48) = 64.
        assert result[0].shape == (1, 56, 64, 3)

    def test_spacing_horizontal(self):
        """Horizontal spacing is inserted between the images."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=24)
        spacing_width = 16

        result = node.stitch(image1, "right", False, spacing_width, "white", image2)

        # Width is 32 + 16 (spacing) + 24 = 72.
        assert result[0].shape == (1, 32, 72, 3)

    def test_spacing_vertical(self):
        """Vertical spacing is inserted between the images."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=24, width=32)
        spacing_width = 16

        result = node.stitch(image1, "down", False, spacing_width, "white", image2)

        # Height is 32 + 16 (spacing) + 24 = 72.
        assert result[0].shape == (1, 72, 32, 3)

    def test_spacing_color_values(self):
        """The spacing strip carries the requested colour."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # With image1 32 wide and 16 columns of spacing, the strip occupies
        # columns 32..47 of the stitched output.
        result_white = node.stitch(image1, "right", False, 16, "white", image2)
        spacing_region = result_white[0][:, :, 32:48, :]
        assert torch.all(spacing_region >= 0.9)

        result_black = node.stitch(image1, "right", False, 16, "black", image2)
        spacing_region = result_black[0][:, :, 32:48, :]
        assert torch.all(spacing_region <= 0.1)

    def test_odd_spacing_width_made_even(self):
        """An odd spacing width is bumped to the next even value."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        result = node.stitch(image1, "right", False, 15, "white", image2)

        # 15 is expected to become 16, giving 32 + 16 + 32 = 80.
        assert result[0].shape == (1, 32, 80, 3)

    def test_batch_size_matching(self):
        """Different batch sizes yield the larger batch size in the output."""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=32, width=32)
        image2 = self.create_test_image(batch_size=1, height=32, width=32)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape == (2, 32, 64, 3)

    def test_channel_matching_rgb_to_rgba(self):
        """RGB image1 with RGBA image2 produces a 4-channel result."""
        node = ImageStitch()
        image1 = self.create_test_image(channels=3)
        image2 = self.create_test_image(channels=4)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape[-1] == 4

    def test_channel_matching_rgba_to_rgb(self):
        """RGBA image1 with RGB image2 produces a 4-channel result."""
        node = ImageStitch()
        image1 = self.create_test_image(channels=4)
        image2 = self.create_test_image(channels=3)

        result = node.stitch(image1, "right", False, 0, "white", image2)

        assert result[0].shape[-1] == 4

    def test_all_color_options(self):
        """Every spacing colour option produces the expected output shape."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        colors = ["white", "black", "red", "green", "blue"]

        for color in colors:
            result = node.stitch(image1, "right", False, 16, color, image2)
            # 32 + 16 (spacing) + 32 = 80 columns; name the colour on failure.
            assert result[0].shape == (1, 32, 80, 3), color

    def test_all_directions(self):
        """Every direction option produces the expected output shape."""
        node = ImageStitch()
        image1 = self.create_test_image(height=32, width=32)
        image2 = self.create_test_image(height=32, width=32)

        # Resolve the expected shape before asserting. Writing
        # ``assert a == b if cond else c`` parses as
        # ``assert ((a == b) if cond else c)``, which asserts a non-empty
        # (always truthy) tuple for the vertical directions.
        expected_shapes = {
            "right": (1, 32, 64, 3),
            "left": (1, 32, 64, 3),
            "up": (1, 64, 32, 3),
            "down": (1, 64, 32, 3),
        }

        for direction, expected_shape in expected_shapes.items():
            result = node.stitch(image1, direction, False, 0, "white", image2)
            assert result[0].shape == expected_shape, direction

    def test_batch_size_channel_spacing_integration(self):
        """Batch matching, channel matching, size matching and spacing together."""
        node = ImageStitch()
        image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
        image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)

        result = node.stitch(image1, "right", True, 8, "red", image2)

        assert result[0].shape[0] == 2    # larger batch size
        assert result[0].shape[-1] == 4   # larger channel count
        assert result[0].shape[1] == 64   # image1's height

        # image2 resized to height 64 at its 32:32 aspect ratio -> 64 wide.
        expected_image2_width = int(64 * (32 / 32))
        expected_total_width = 48 + 8 + expected_image2_width
        assert result[0].shape[2] == expected_total_width
| |
|
| |
|