""" Tests for pose image shape handling throughout the pipeline. This test file verifies that: 1. draw_pose returns the correct shape when grayscale=True 2. numpy_to_torch handles different input shapes correctly 3. The expected tensor shapes for the model are correct """ import numpy as np import pytest import torch class TestDrawPoseShapes: """Tests for draw_pose output shapes.""" def test_draw_pose_grayscale_returns_2d_array(self): """draw_pose with grayscale=True should return shape (H, W), not (H, W, 1).""" from fashn_vton.dwpose import draw_pose from fashn_vton.utils import get_dummy_dw_keypoints H, W = 512, 384 dummy_pose = get_dummy_dw_keypoints() result = draw_pose(dummy_pose, H, W, canvas_value=0, grayscale=True) # Grayscale should return 2D array (H, W) assert result.ndim == 2, f"Expected 2D array, got {result.ndim}D with shape {result.shape}" assert result.shape == (H, W), f"Expected shape ({H}, {W}), got {result.shape}" assert result.dtype == np.uint8, f"Expected dtype uint8, got {result.dtype}" def test_draw_pose_rgb_returns_3d_array(self): """draw_pose with grayscale=False should return shape (H, W, 3).""" from fashn_vton.dwpose import draw_pose from fashn_vton.utils import get_dummy_dw_keypoints H, W = 512, 384 dummy_pose = get_dummy_dw_keypoints() result = draw_pose(dummy_pose, H, W, canvas_value=0, grayscale=False) # RGB should return 3D array (H, W, 3) assert result.ndim == 3, f"Expected 3D array, got {result.ndim}D with shape {result.shape}" assert result.shape == (H, W, 3), f"Expected shape ({H}, {W}, 3), got {result.shape}" assert result.dtype == np.uint8, f"Expected dtype uint8, got {result.dtype}" class TestNumpyToTorchShapes: """Tests for numpy_to_torch conversion behavior.""" def test_numpy_to_torch_3d_rgb_image(self): """numpy_to_torch should convert (H, W, C) to (C, H, W).""" from fashn_vton.utils import numpy_to_torch H, W, C = 512, 384, 3 img = np.zeros((H, W, C), dtype=np.uint8) result = numpy_to_torch(img) assert result.shape == (C, H, W), f"Expected shape ({C}, {H}, {W}), got {result.shape}" def test_numpy_to_torch_2d_grayscale_passes_through(self): """numpy_to_torch with 2D input should pass through without permutation. The pipeline's numpy_to_torch checks ndim == 3 before permuting, so 2D grayscale images pass through unchanged. """ from fashn_vton.utils import numpy_to_torch H, W = 512, 384 img = np.zeros((H, W), dtype=np.uint8) # 2D grayscale result = numpy_to_torch(img) # 2D input should pass through unchanged assert result.shape == (H, W), f"Expected shape ({H}, {W}), got {result.shape}" assert result.ndim == 2, f"Expected 2D tensor, got {result.ndim}D" def test_common_numpy_to_torch_with_2d_input(self): """Test the common package's numpy_to_torch with 2D input. The common package has: if permute and image.ndim == 3: ... So it should handle 2D gracefully by not permuting. """ # Replicate the common package's numpy_to_torch logic def common_numpy_to_torch(image: np.ndarray, permute: bool = True) -> torch.Tensor: image = torch.from_numpy(image) if permute and image.ndim == 3: image = image.permute(2, 0, 1) return image H, W = 512, 384 img = np.zeros((H, W), dtype=np.uint8) # 2D grayscale result = common_numpy_to_torch(img) # With 2D input, no permutation should happen assert result.shape == (H, W), f"Expected shape ({H}, {W}), got {result.shape}" assert result.ndim == 2, f"Expected 2D tensor, got {result.ndim}D" class TestExpectedModelTensorShapes: """Tests for expected tensor shapes going into the model.""" def test_expected_pose_tensor_shape_single_sample(self): """Verify the expected pose tensor shape for a single sample. Based on the model architecture: - x_embedder expects: [x (3), ca_images (3), person_poses (1)] = 7 channels - garment_embedder expects: [garment_images (3), garment_poses (1)] = 4 channels So poses should be (batch, 1, H, W). """ batch_size = 1 H, W = 768, 576 # Model target size # Expected pose tensor shape expected_pose_shape = (batch_size, 1, H, W) # Create a grayscale pose image as draw_pose would return grayscale_pose = np.zeros((H, W), dtype=np.uint8) # The CORRECT way to convert: unsqueeze to add channel dim tensor = torch.from_numpy(grayscale_pose).unsqueeze(0).unsqueeze(0) assert tensor.shape == expected_pose_shape, \ f"Expected shape {expected_pose_shape}, got {tensor.shape}" def test_expected_pose_tensor_shape_multi_sample(self): """Verify pose tensor shape for multiple samples.""" batch_size = 4 H, W = 768, 576 expected_pose_shape = (batch_size, 1, H, W) grayscale_pose = np.zeros((H, W), dtype=np.uint8) tensor = torch.from_numpy(grayscale_pose).unsqueeze(0).unsqueeze(0) tensor = tensor.repeat(batch_size, 1, 1, 1) assert tensor.shape == expected_pose_shape, \ f"Expected shape {expected_pose_shape}, got {tensor.shape}" def test_pipeline_prepare_tensor_grayscale(self): """Test the prepare_tensor logic from pipeline.py with grayscale input. The fixed numpy_to_torch handles 2D input correctly, so prepare_tensor doesn't need special grayscale handling anymore. """ from fashn_vton.utils import normalize_uint8_to_neg1_1, numpy_to_torch H, W = 768, 576 num_samples = 2 # Grayscale pose image as draw_pose returns grayscale_pose = np.zeros((H, W), dtype=np.uint8) # The prepare_tensor function from pipeline.py (fixed version) def prepare_tensor(img: np.ndarray) -> torch.Tensor: t = numpy_to_torch(img).unsqueeze(0) # (H, W) -> (1, H, W) t = normalize_uint8_to_neg1_1(t) t = t.repeat(num_samples, 1, 1, 1) # Prepends dim: (1, H, W) -> (N, 1, H, W) return t result = prepare_tensor(grayscale_pose) expected_shape = (num_samples, 1, H, W) assert result.shape == expected_shape, \ f"Expected shape {expected_shape}, got {result.shape}" if __name__ == "__main__": pytest.main([__file__, "-v"])