| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import PIL.Image |
| | import torch |
| |
|
| | from diffusers.utils import load_image |
| | from diffusers.utils.constants import ( |
| | DECODE_ENDPOINT_FLUX, |
| | DECODE_ENDPOINT_SD_V1, |
| | DECODE_ENDPOINT_SD_XL, |
| | ENCODE_ENDPOINT_FLUX, |
| | ENCODE_ENDPOINT_SD_V1, |
| | ENCODE_ENDPOINT_SD_XL, |
| | ) |
| | from diffusers.utils.remote_utils import ( |
| | remote_decode, |
| | remote_encode, |
| | ) |
| |
|
| | from ..testing_utils import ( |
| | enable_full_determinism, |
| | slow, |
| | ) |
| |
|
| |
|
# Called once at import time, before any of the test classes below are defined.
# NOTE(review): presumably pins seeds / deterministic settings for the whole
# module — confirm against ..testing_utils.
enable_full_determinism()

# URL of the sample image that the mixins' get_dummy_inputs() fetch through
# load_image() when no image has been set on the test instance.
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
| |
|
| |
|
class RemoteAutoencoderKLEncodeMixin:
    """Encode/decode round-trip test shared by the remote AutoencoderKL test cases.

    Meant to be mixed into a ``unittest.TestCase`` subclass that overrides the
    class attributes below; the test methods rely on ``assertEqual``.
    """

    # Expected size of dimension 1 of the tensor returned by ``remote_encode``.
    channels: int = None
    # Endpoint handed to ``remote_encode``.
    endpoint: str = None
    # Endpoint handed to ``remote_decode``.
    decode_endpoint: str = None
    # Declared for subclasses; not read by any method of this mixin.
    dtype: torch.dtype = None
    # Forwarded unchanged to both ``remote_encode`` and ``remote_decode``.
    scaling_factor: float = None
    shift_factor: float = None
    # Input image; fetched from ``IMAGE`` on first use when left as ``None``.
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        """Return the keyword arguments passed to ``remote_encode``.

        Loads the sample image onto the instance the first time it is needed.
        """
        if self.image is None:
            self.image = load_image(IMAGE)
        return {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }

    def test_image_input(self):
        # Encode the sample image, check the latent shape, then decode the
        # latent and check that the image dimensions survive the round trip.
        encode_kwargs = self.get_dummy_inputs()
        source = encode_kwargs["image"]
        height = source.height
        width = source.width

        latent = remote_encode(**encode_kwargs)
        expected_shape = [1, self.channels, height // 8, width // 8]
        self.assertEqual(list(latent.shape), expected_shape)

        decode_kwargs = {
            "tensor": latent,
            "endpoint": self.decode_endpoint,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "image_format": "png",
        }
        decoded = remote_decode(**decode_kwargs)
        self.assertEqual(decoded.height, height)
        self.assertEqual(decoded.width, width)
| | |
| | |
| | |
| |
|
| |
|
class RemoteAutoencoderKLSDv1Tests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Runs the encode mixin's tests against the SD v1 encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1

    # Values read by the mixin (``dtype`` is declared but not read there).
    channels = 4
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
| |
|
| |
|
class RemoteAutoencoderKLSDXLTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Runs the encode mixin's tests against the SD XL encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL

    # Values read by the mixin (``dtype`` is declared but not read there).
    channels = 4
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
| |
|
| |
|
class RemoteAutoencoderKLFluxTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Runs the encode mixin's tests against the Flux encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Values read by the mixin (``dtype`` is declared but not read there).
    channels = 16
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
| |
|
| |
|
class RemoteAutoencoderKLEncodeSlowTestMixin:
    """Multi-resolution encode/decode round-trip test for remote AutoencoderKL endpoints.

    Meant to be mixed into a ``unittest.TestCase`` subclass that overrides the
    class attributes below; the test relies on ``assertEqual`` and ``subTest``.
    """

    # Expected size of dimension 1 of the tensor returned by ``remote_encode``.
    channels: int = 4
    # Endpoint handed to ``remote_encode``.
    endpoint: str = None
    # Endpoint handed to ``remote_decode``.
    decode_endpoint: str = None
    # Declared for subclasses; not read by any method of this mixin.
    dtype: torch.dtype = None
    # Forwarded unchanged to both ``remote_encode`` and ``remote_decode``.
    scaling_factor: float = None
    shift_factor: float = None
    # Input image; fetched from ``IMAGE`` on first use when left as ``None``.
    image: PIL.Image.Image = None

    # Edge lengths tried for both height and width (every pair is tested).
    # A tuple gives a defined iteration order and removes the duplicated literal.
    resolutions = (
        320,
        512,
        640,
        704,
        896,
        1024,
        1208,
        1384,
        1536,
        1608,
        1864,
        2048,
    )

    def get_dummy_inputs(self):
        """Return the keyword arguments passed to ``remote_encode``.

        Loads the sample image onto the instance the first time it is needed.
        """
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_multi_res(self):
        # For every (height, width) pair in ``resolutions``: resize the sample
        # image, encode it, check the latent shape, decode the latent and check
        # that the decoded image has the requested size.
        #
        # Local imports: only this test needs them and the module's import
        # block is left untouched.
        import os
        import tempfile

        base_inputs = self.get_dummy_inputs()
        # Keep the pristine image so every resolution is resampled from the
        # source rather than from the previous iteration's resized copy.
        source_image = base_inputs["image"]

        # The decoded images are still written out as PNG, but into a temporary
        # directory that is removed afterwards instead of the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for height in self.resolutions:
                for width in self.resolutions:
                    # subTest labels a failure with the resolution that caused it.
                    with self.subTest(height=height, width=width):
                        inputs = {
                            **base_inputs,
                            "image": source_image.resize((width, height)),
                        }
                        output = remote_encode(**inputs)
                        self.assertEqual(
                            list(output.shape),
                            [1, self.channels, height // 8, width // 8],
                        )
                        decoded = remote_decode(
                            tensor=output,
                            endpoint=self.decode_endpoint,
                            scaling_factor=self.scaling_factor,
                            shift_factor=self.shift_factor,
                            image_format="png",
                        )
                        self.assertEqual(decoded.height, height)
                        self.assertEqual(decoded.width, width)
                        decoded.save(os.path.join(tmp_dir, f"test_multi_res_{height}_{width}.png"))
| |
|
| |
|
@slow
class RemoteAutoencoderKLSDv1SlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Runs the slow multi-resolution mixin against the SD v1 endpoints.

    ``channels`` is inherited from the mixin.
    """

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1

    # Values read by the mixin (``dtype`` is declared but not read there).
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None
| |
|
| |
|
@slow
class RemoteAutoencoderKLSDXLSlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Runs the slow multi-resolution mixin against the SD XL endpoints.

    ``channels`` is inherited from the mixin.
    """

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL

    # Values read by the mixin (``dtype`` is declared but not read there).
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None
| |
|
| |
|
@slow
class RemoteAutoencoderKLFluxSlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Runs the slow multi-resolution mixin against the Flux endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Values read by the mixin (``dtype`` is declared but not read there).
    channels = 16
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
| |
|