| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| |
|
| | from diffusers.image_processor import VaeImageProcessor |
| |
|
| |
|
class ImageProcessorTest(unittest.TestCase):
    """Round-trip and shape tests for ``VaeImageProcessor``.

    Most tests push an input through ``preprocess`` followed by
    ``postprocess`` and assert that the result matches either the original
    input or the result obtained from an equivalent input layout (lower-rank
    array, list of arrays, ...).
    """

    # Output formats exercised by the round-trip tests.
    OUTPUT_TYPES = ("pt", "np", "pil")

    @property
    def dummy_sample(self):
        """Return a fresh random tensor of shape (1, 3, 8, 8).

        A new tensor is drawn on every access, so two accesses do not
        return the same values.
        """
        batch_size = 1
        num_channels = 3
        height = 8
        width = 8

        return torch.rand((batch_size, num_channels, height, width))

    @property
    def dummy_mask(self):
        """Return a fresh random single-channel tensor of shape (1, 1, 8, 8)."""
        batch_size = 1
        num_channels = 1
        height = 8
        width = 8

        return torch.rand((batch_size, num_channels, height, width))

    @staticmethod
    def _roundtrip(image_processor, image, output_type="np"):
        """Run ``image`` through ``preprocess`` and then ``postprocess``."""
        return image_processor.postprocess(
            image_processor.preprocess(image),
            output_type=output_type,
        )

    def to_np(self, image):
        """Convert a processor output to a NumPy array for comparison.

        * a sequence whose first element is a PIL image is stacked along a
          new leading axis;
        * a torch tensor is moved to CPU and its axes are permuted with
          ``(0, 2, 3, 1)`` (channel axis moved last);
        * anything else is returned unchanged.
        """
        if isinstance(image[0], PIL.Image.Image):
            return np.stack([np.array(i) for i in image], axis=0)
        if isinstance(image, torch.Tensor):
            return image.cpu().numpy().transpose(0, 2, 3, 1)
        return image

    def test_vae_image_processor_pt(self):
        """A torch input survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

        input_pt = self.dummy_sample
        input_np = self.to_np(input_pt)

        for output_type in self.OUTPUT_TYPES:
            out = self._roundtrip(image_processor, input_pt, output_type)
            out_np = self.to_np(out)
            # PIL output is compared on the 0..255 scale.
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_np(self):
        """A NumPy input survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)

        for output_type in self.OUTPUT_TYPES:
            out = self._roundtrip(image_processor, input_np, output_type)

            out_np = self.to_np(out)
            # PIL output is compared on the 0..255 scale.
            in_np = (input_np * 255).round() if output_type == "pil" else input_np
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_vae_image_processor_pil(self):
        """A PIL input survives the round trip for every output type."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)

        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
        input_pil = image_processor.numpy_to_pil(input_np)

        # Stack the reference once so the whole batch is compared against the
        # whole output batch (instead of each image against the full batch).
        # Cast to float so the difference below is signed, not uint8 wrap-around.
        in_np = self.to_np(input_pil).astype(np.float32)

        for output_type in self.OUTPUT_TYPES:
            out = self._roundtrip(image_processor, input_pil, output_type)
            out_np = self.to_np(out)
            if output_type != "pil":
                # Non-PIL outputs are brought to the 0..255 scale of the reference.
                out_np = (out_np * 255).round()
            assert (
                np.abs(in_np - out_np).max() < 1e-6
            ), f"decoded output does not match input for output_type {output_type}"

    def test_preprocess_input_3d(self):
        """An input without a batch axis gives the same result as the batched one."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_3d = input_pt_4d.squeeze(0)

        out_pt_4d = self._roundtrip(image_processor, input_pt_4d)
        out_pt_3d = self._roundtrip(image_processor, input_pt_3d)

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_3d = input_np_4d.squeeze(0)

        out_np_4d = self._roundtrip(image_processor, input_np_4d)
        out_np_3d = self._roundtrip(image_processor, input_np_3d)

        assert np.abs(out_pt_4d - out_pt_3d).max() < 1e-6
        assert np.abs(out_np_4d - out_np_3d).max() < 1e-6

    def test_preprocess_input_list(self):
        """A list of per-image arrays gives the same result as the batched array."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)

        input_pt_4d = self.dummy_sample
        input_pt_list = list(input_pt_4d)

        out_pt_4d = self._roundtrip(image_processor, input_pt_4d)
        out_pt_list = self._roundtrip(image_processor, input_pt_list)

        input_np_4d = self.to_np(self.dummy_sample)
        input_np_list = list(input_np_4d)

        out_np_4d = self._roundtrip(image_processor, input_np_4d)
        out_np_list = self._roundtrip(image_processor, input_np_list)

        assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
        assert np.abs(out_np_4d - out_np_list).max() < 1e-6

    def test_preprocess_input_mask_3d(self):
        """Masks given as 4D, 3D or 2D arrays produce identical results."""
        image_processor = VaeImageProcessor(
            do_resize=False, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

        input_pt_4d = self.dummy_mask
        input_pt_3d = input_pt_4d.squeeze(0)
        input_pt_2d = input_pt_3d.squeeze(0)

        out_pt_4d = self._roundtrip(image_processor, input_pt_4d)
        out_pt_3d = self._roundtrip(image_processor, input_pt_3d)
        out_pt_2d = self._roundtrip(image_processor, input_pt_2d)

        input_np_4d = self.to_np(self.dummy_mask)
        input_np_3d = input_np_4d.squeeze(0)  # batch axis dropped
        input_np_3d_1 = input_np_4d.squeeze(-1)  # channel axis dropped
        input_np_2d = input_np_3d.squeeze(-1)

        out_np_4d = self._roundtrip(image_processor, input_np_4d)
        out_np_3d = self._roundtrip(image_processor, input_np_3d)
        out_np_3d_1 = self._roundtrip(image_processor, input_np_3d_1)
        out_np_2d = self._roundtrip(image_processor, input_np_2d)

        assert np.abs(out_pt_4d - out_pt_3d).max() == 0
        assert np.abs(out_pt_4d - out_pt_2d).max() == 0
        assert np.abs(out_np_4d - out_np_3d).max() == 0
        assert np.abs(out_np_4d - out_np_3d_1).max() == 0
        assert np.abs(out_np_4d - out_np_2d).max() == 0

    def test_preprocess_input_mask_list(self):
        """Wrapping a mask in a one-element list does not change the result."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)

        input_pt_4d = self.dummy_mask
        input_pt_3d = input_pt_4d.squeeze(0)
        input_pt_2d = input_pt_3d.squeeze(0)

        for input_pt in (input_pt_4d, input_pt_3d, input_pt_2d):
            out_pt = self._roundtrip(image_processor, input_pt)
            out_pt_list = self._roundtrip(image_processor, [input_pt])
            assert np.abs(out_pt - out_pt_list).max() < 1e-6

        input_np_4d = self.to_np(self.dummy_mask)
        input_np_3d = input_np_4d.squeeze(0)
        input_np_2d = input_np_3d.squeeze(-1)

        for input_np in (input_np_4d, input_np_3d, input_np_2d):
            out_np = self._roundtrip(image_processor, input_np)
            out_np_list = self._roundtrip(image_processor, [input_np])
            assert np.abs(out_np - out_np_list).max() < 1e-6

    def test_preprocess_input_mask_3d_batch(self):
        """A 3D batch of masks matches the same masks passed as a list."""
        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)

        # Build a batch of two masks.
        dummy_mask_batch = torch.cat([self.dummy_mask] * 2, dim=0)

        # Drop the channel axis to obtain 3D inputs.
        input_pt_3d = dummy_mask_batch.squeeze(1)
        input_np_3d = self.to_np(dummy_mask_batch).squeeze(-1)

        input_pt_3d_list = list(input_pt_3d)
        input_np_3d_list = list(input_np_3d)

        out_pt_3d = self._roundtrip(image_processor, input_pt_3d)
        out_pt_3d_list = self._roundtrip(image_processor, input_pt_3d_list)

        assert np.abs(out_pt_3d - out_pt_3d_list).max() < 1e-6

        out_np_3d = self._roundtrip(image_processor, input_np_3d)
        out_np_3d_list = self._roundtrip(image_processor, input_np_3d_list)

        assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6

    def test_vae_image_processor_resize_pt(self):
        """``resize`` on a torch tensor halves height and width as requested."""
        image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
        input_pt = self.dummy_sample
        b, c, h, w = input_pt.shape
        scale = 2
        out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
        exp_pt_shape = (b, c, h // scale, w // scale)
        assert (
            out_pt.shape == exp_pt_shape
        ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."

    def test_vae_image_processor_resize_np(self):
        """``resize`` on a NumPy array halves height and width as requested."""
        image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
        input_pt = self.dummy_sample
        b, c, h, w = input_pt.shape
        scale = 2
        input_np = self.to_np(input_pt)
        out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
        exp_np_shape = (b, h // scale, w // scale, c)
        assert (
            out_np.shape == exp_np_shape
        ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
| |
|