| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import tempfile |
| | import unittest |
| | from pathlib import Path |
| |
|
| | from diffusers import ( |
| | DDIMScheduler, |
| | DDPMScheduler, |
| | DPMSolverMultistepScheduler, |
| | EulerAncestralDiscreteScheduler, |
| | EulerDiscreteScheduler, |
| | PNDMScheduler, |
| | logging, |
| | ) |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| |
|
| | from ..testing_utils import CaptureLogger |
| |
|
| |
|
class SampleObject(ConfigMixin):
    """Config fixture registering the arguments ``a``, ``b``, ``c``, ``d`` and ``e``."""

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]):
        # The body is intentionally empty: the decorator captures the arguments into
        # ``self.config``. The list default is deliberate (the tests compare against
        # ``[1, 3]``) and is never mutated here.
        ...
| |
|
| |
|
class SampleObject2(ConfigMixin):
    """Config fixture like ``SampleObject`` but with an ``f`` argument instead of ``e``."""

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", f=[1, 3]):
        # Nothing to do: ``@register_to_config`` records the arguments in ``self.config``.
        # The list default is intentional and never mutated.
        ...
| |
|
| |
|
class SampleObject3(ConfigMixin):
    """Config fixture registering both ``e`` and ``f`` (each defaulting to ``[1, 3]``)."""

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3], f=[1, 3]):
        # Empty on purpose: the decorator stores the arguments in ``self.config``.
        # The list defaults are intentional and never mutated.
        ...
| |
|
| |
|
class SampleObject4(ConfigMixin):
    """Config fixture with ``e`` and ``f`` whose defaults differ from the other fixtures.

    Here ``e`` defaults to ``[1, 5]`` and ``f`` to ``[5, 4]``. This lets a test tell
    whether a value came from a loaded config or from this class's own default.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(self, a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 5], f=[5, 4]):
        # The decorator does all the work of recording the arguments in ``self.config``.
        # The list defaults are intentional and never mutated.
        ...
| |
|
| |
|
class SampleObjectPaths(ConfigMixin):
    """Config fixture whose registered defaults are ``pathlib.Path`` objects."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        test_file_1=Path("foo/bar"),
        # Contains a space and a literal backslash to exercise path serialization.
        test_file_2=Path("foo bar\\bar"),
    ):
        # Empty on purpose: ``@register_to_config`` records the arguments.
        ...
| |
|
| |
|
class ConfigTester(unittest.TestCase):
    """Tests for ``ConfigMixin``/``register_to_config`` and scheduler config loading."""

    # Hub repo whose ``scheduler`` subfolder is loaded by the scheduler tests.
    # NOTE(review): the test names suggest it stores a PNDM scheduler config; confirm.
    _TINY_SD_REPO = "hf-internal-testing/tiny-stable-diffusion-torch"
    # Numeric value of the standard ``logging.WARNING`` level.
    _WARNING_LEVEL = 30

    @staticmethod
    def _sample_object_defaults():
        """Return a fresh dict of ``SampleObject``'s registered default values."""
        return {"a": 2, "b": 5, "c": (2, 5), "d": "for diffusion", "e": [1, 3]}

    def _config_logger_at_warning(self):
        """Return the ``diffusers.configuration_utils`` logger set to WARNING.

        The previous level is restored when the test finishes, so that changing
        the shared logger does not leak into other tests.
        """
        logger = logging.get_logger("diffusers.configuration_utils")
        # Arguments are evaluated now, so the cleanup restores the current level.
        self.addCleanup(logger.setLevel, logger.level)
        logger.setLevel(self._WARNING_LEVEL)
        return logger

    def _assert_loads_without_warnings(self, scheduler_cls):
        """Load ``scheduler_cls`` from the tiny SD repo and check nothing is logged.

        Args:
            scheduler_cls: Scheduler class whose ``from_pretrained`` is exercised.
        """
        logger = self._config_logger_at_warning()

        with CaptureLogger(logger) as cap_logger:
            scheduler = scheduler_cls.from_pretrained(self._TINY_SD_REPO, subfolder="scheduler")

        assert scheduler.__class__ == scheduler_cls
        # Nothing at WARNING level or above may be logged while loading.
        assert cap_logger.out == ""

    def test_load_not_from_mixin(self):
        # Calling ``load_config`` on the bare mixin is rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.load_config("dummy_path")

    def test_register_to_config(self):
        defaults = self._sample_object_defaults()
        # Each case: (positional args, keyword args, expected overrides of ``defaults``).
        cases = [
            ((), {}, {}),
            # ``_name_or_path`` is accepted without changing any registered value.
            ((), {"_name_or_path": "lalala"}, {}),
            ((), {"c": 6}, {"c": 6}),
            # Positional arguments are registered as well.
            ((1,), {"c": 6}, {"a": 1, "c": 6}),
        ]
        for args, kwargs, overrides in cases:
            config = SampleObject(*args, **kwargs).config
            for key, expected in {**defaults, **overrides}.items():
                assert config[key] == expected, (args, kwargs, key)

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        for key, expected in self._sample_object_defaults().items():
            assert config[key] == expected, key

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
            new_config = new_obj.config

        config = dict(config)
        new_config = dict(new_config)

        # JSON has no tuple type, so ``c`` comes back from disk as a list.
        assert config.pop("c") == (2, 5)
        assert new_config.pop("c") == [2, 5]
        # Only the in-memory config is expected to carry ``_use_default_values``.
        config.pop("_use_default_values")
        assert config == new_config

    def test_load_ddim_from_pndm(self):
        self._assert_loads_without_warnings(DDIMScheduler)

    def test_load_euler_from_pndm(self):
        self._assert_loads_without_warnings(EulerDiscreteScheduler)

    def test_load_euler_ancestral_from_pndm(self):
        self._assert_loads_without_warnings(EulerAncestralDiscreteScheduler)

    def test_load_pndm(self):
        self._assert_loads_without_warnings(PNDMScheduler)

    def test_overwrite_config_on_load(self):
        logger = self._config_logger_at_warning()

        with CaptureLogger(logger) as cap_logger:
            ddpm = DDPMScheduler.from_pretrained(
                self._TINY_SD_REPO,
                subfolder="scheduler",
                prediction_type="sample",
                beta_end=8,
            )

        with CaptureLogger(logger) as cap_logger_2:
            ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

        assert ddpm.__class__ == DDPMScheduler
        # Keyword arguments to ``from_pretrained`` take precedence in the resulting config.
        assert ddpm.config.prediction_type == "sample"
        assert ddpm.config.beta_end == 8
        assert ddpm_2.config.beta_start == 88

        # Overriding values must not log anything at WARNING level or above.
        assert cap_logger.out == ""
        assert cap_logger_2.out == ""

    def test_load_dpmsolver(self):
        self._assert_loads_without_warnings(DPMSolverMultistepScheduler)

    def test_use_default_values(self):
        # Start from a config built purely from SampleObject's defaults:
        # a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3].
        sample = SampleObject()

        public_config = {k: v for k, v in sample.config.items() if not k.startswith("_")}

        # Every public key of a default-constructed object is recorded as a default.
        assert set(public_config) == set(sample.config._use_default_values)

        with tempfile.TemporaryDirectory() as tmpdirname:
            sample.save_config(tmpdirname)

            # The saved config has no ``f``. SampleObject2 therefore falls back to its
            # default for ``f`` and records ``f`` in ``_use_default_values``.
            loaded = SampleObject2.from_config(SampleObject2.load_config(tmpdirname))

            assert "f" in loaded.config._use_default_values
            assert loaded.config.f == [1, 3]

        # ``f`` is marked as a default, so SampleObject4 uses its own default [5, 4]
        # rather than the [1, 3] stored in the config.
        new_config = SampleObject4.from_config(loaded.config)
        assert new_config.config.f == [5, 4]

        # Once ``f`` is no longer marked as a default, the stored [1, 3] is used.
        # ``pop()`` removes the last entry; the assertions below rely on that being "f".
        loaded.config._use_default_values.pop()
        new_config_2 = SampleObject4.from_config(loaded.config)
        assert new_config_2.config.f == [1, 3]

        # ``e`` came from the saved config, so it stays [1, 3] rather than
        # SampleObject4's default [1, 5].
        assert new_config_2.config.e == [1, 3]

    def test_check_path_types(self):
        # ``Path`` values must serialize as POSIX-style strings so the config stays JSON-safe.
        config = SampleObjectPaths()
        result = json.loads(config.to_json_string())
        assert result["test_file_1"] == config.config.test_file_1.as_posix()
        assert result["test_file_2"] == config.config.test_file_2.as_posix()
| |
|