|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import PIL.Image |
|
|
import torch |
|
|
|
|
|
from diffusers.utils import load_image |
|
|
from diffusers.utils.constants import ( |
|
|
DECODE_ENDPOINT_FLUX, |
|
|
DECODE_ENDPOINT_SD_V1, |
|
|
DECODE_ENDPOINT_SD_XL, |
|
|
ENCODE_ENDPOINT_FLUX, |
|
|
ENCODE_ENDPOINT_SD_V1, |
|
|
ENCODE_ENDPOINT_SD_XL, |
|
|
) |
|
|
from diffusers.utils.remote_utils import ( |
|
|
remote_decode, |
|
|
remote_encode, |
|
|
) |
|
|
from diffusers.utils.testing_utils import ( |
|
|
enable_full_determinism, |
|
|
slow, |
|
|
) |
|
|
|
|
|
|
|
|
# Executed once at import time, before any of the test classes below are defined.
enable_full_determinism()


# Location of the picture the test mixins load when no image has been assigned yet.
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
|
|
|
|
|
|
|
|
class RemoteAutoencoderKLEncodeMixin:
    """Encode/decode round-trip check shared by the remote autoencoder test cases.

    The mixin is meant to be combined with ``unittest.TestCase`` (it relies on
    ``assertEqual``); concrete test cases override the class attributes below.
    """

    # Expected size of dimension 1 of the tensor returned by ``remote_encode``.
    channels: int = None
    # Endpoint handed to ``remote_encode``.
    endpoint: str = None
    # Endpoint handed to ``remote_decode``.
    decode_endpoint: str = None
    # Not read by any method of this mixin.
    dtype: torch.dtype = None
    # Both factors are forwarded unchanged to ``remote_encode`` and ``remote_decode``.
    scaling_factor: float = None
    shift_factor: float = None
    # Input picture; loaded from ``IMAGE`` on first use while still ``None``.
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        """Return the keyword arguments passed to ``remote_encode``.

        The picture is loaded from ``IMAGE`` and stored on the instance the
        first time it is needed.
        """
        picture = self.image
        if picture is None:
            picture = self.image = load_image(IMAGE)
        return dict(
            endpoint=self.endpoint,
            image=picture,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
        )

    def test_image_input(self):
        """Encode the picture, check the tensor shape, decode it and check the size."""
        encode_kwargs = self.get_dummy_inputs()
        source = encode_kwargs["image"]
        height, width = source.height, source.width

        latent = remote_encode(**encode_kwargs)
        # The test expects each spatial side of the encoded tensor to be the
        # picture side floor-divided by 8.
        expected_shape = [1, self.channels, height // 8, width // 8]
        self.assertEqual(list(latent.shape), expected_shape)

        roundtrip = remote_decode(
            tensor=latent,
            endpoint=self.decode_endpoint,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
            image_format="png",
        )
        # Decoding must give back a picture of the original dimensions.
        self.assertEqual(roundtrip.height, height)
        self.assertEqual(roundtrip.width, width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteAutoencoderKLSDv1Tests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Run the mixin's ``test_image_input`` against the SD v1 encode/decode endpoints."""

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1

    # Expected size of dimension 1 of the encoded tensor.
    channels = 4
    dtype = torch.float16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin.
    scaling_factor = 0.18215
    shift_factor = None
|
|
|
|
|
|
|
|
class RemoteAutoencoderKLSDXLTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Run the mixin's ``test_image_input`` against the SD XL encode/decode endpoints."""

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL

    # Expected size of dimension 1 of the encoded tensor.
    channels = 4
    dtype = torch.float16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin.
    scaling_factor = 0.13025
    shift_factor = None
|
|
|
|
|
|
|
|
class RemoteAutoencoderKLFluxTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Run the mixin's ``test_image_input`` against the Flux encode/decode endpoints."""

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Expected size of dimension 1 of the encoded tensor.
    channels = 16
    dtype = torch.bfloat16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin;
    # this is the only non-slow test case that sets a non-None shift factor.
    scaling_factor = 0.3611
    shift_factor = 0.1159
|
|
|
|
|
|
|
|
class RemoteAutoencoderKLEncodeSlowTestMixin:
    """Multi-resolution encode/decode round-trip check for remote autoencoder endpoints.

    The mixin is meant to be combined with ``unittest.TestCase`` (it relies on
    ``assertEqual``); concrete test cases override the class attributes below.
    """

    # Expected size of dimension 1 of the tensor returned by ``remote_encode``.
    channels: int = 4
    # Endpoint handed to ``remote_encode``.
    endpoint: str = None
    # Endpoint handed to ``remote_decode``.
    decode_endpoint: str = None
    # Not read by any method of this mixin.
    dtype: torch.dtype = None
    # Both factors are forwarded unchanged to ``remote_encode`` and ``remote_decode``.
    scaling_factor: float = None
    shift_factor: float = None
    # Input picture; loaded from ``IMAGE`` on first use while still ``None``.
    image: PIL.Image.Image = None

    # Side lengths in pixels; every (height, width) pair drawn from this tuple is
    # exercised. A tuple (rather than two duplicated set literals) gives a single
    # definition and a fixed, readable sweep order.
    resolutions = (320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048)

    def get_dummy_inputs(self):
        """Return the keyword arguments passed to ``remote_encode``.

        The picture is loaded from ``IMAGE`` and stored on the instance the
        first time it is needed.
        """
        if self.image is None:
            self.image = load_image(IMAGE)
        inputs = {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }
        return inputs

    def test_multi_res(self):
        """Round-trip the picture through encode and decode at every resolution pair.

        For each (height, width) combination the picture is resized, encoded,
        checked for the expected tensor shape, decoded, checked for the
        original dimensions and written out as a PNG file.
        """
        # Local imports keep this block self-contained.
        import os
        import tempfile

        inputs = self.get_dummy_inputs()
        # Keep the untouched picture: resizing the previous iteration's result
        # again and again would compound resampling loss across the sweep.
        source_image = inputs["image"]

        # The decoded pictures are still written as PNG files, but into a
        # throw-away directory so the test leaves nothing in the working directory.
        with tempfile.TemporaryDirectory() as output_dir:
            for height in self.resolutions:
                for width in self.resolutions:
                    # Appended to assertion failures so the offending pair is visible.
                    label = f"height={height}, width={width}"

                    inputs["image"] = source_image.resize((width, height))
                    output = remote_encode(**inputs)
                    # The test expects each spatial side of the encoded tensor
                    # to be the picture side floor-divided by 8.
                    self.assertEqual(
                        list(output.shape),
                        [1, self.channels, height // 8, width // 8],
                        msg=label,
                    )

                    decoded = remote_decode(
                        tensor=output,
                        endpoint=self.decode_endpoint,
                        scaling_factor=self.scaling_factor,
                        shift_factor=self.shift_factor,
                        image_format="png",
                    )
                    self.assertEqual(decoded.height, height, msg=label)
                    self.assertEqual(decoded.width, width, msg=label)
                    decoded.save(os.path.join(output_dir, f"test_multi_res_{height}_{width}.png"))
|
|
|
|
|
|
|
|
@slow
class RemoteAutoencoderKLSDv1SlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Run the slow mixin's ``test_multi_res`` against the SD v1 encode/decode endpoints.

    ``channels`` is not overridden here, so the mixin's value applies.
    """

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1

    dtype = torch.float16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin.
    scaling_factor = 0.18215
    shift_factor = None
|
|
|
|
|
|
|
|
@slow
class RemoteAutoencoderKLSDXLSlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Run the slow mixin's ``test_multi_res`` against the SD XL encode/decode endpoints.

    ``channels`` is not overridden here, so the mixin's value applies.
    """

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL

    dtype = torch.float16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin.
    scaling_factor = 0.13025
    shift_factor = None
|
|
|
|
|
|
|
|
@slow
class RemoteAutoencoderKLFluxSlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Run the slow mixin's ``test_multi_res`` against the Flux encode/decode endpoints."""

    # Remote endpoints used for the encode and decode halves of the round trip.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Expected size of dimension 1 of the encoded tensor; overrides the mixin's value.
    channels = 16
    dtype = torch.bfloat16

    # Forwarded unchanged to ``remote_encode`` / ``remote_decode`` by the mixin.
    scaling_factor = 0.3611
    shift_factor = 0.1159
|
|
|