| | import gc |
| | import tempfile |
| | from io import BytesIO |
| |
|
| | import requests |
| | import torch |
| | from huggingface_hub import hf_hub_download, snapshot_download |
| |
|
| | from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name |
| | from diffusers.models.attention_processor import AttnProcessor |
| |
|
| | from ..testing_utils import ( |
| | backend_empty_cache, |
| | nightly, |
| | numpy_cosine_similarity_distance, |
| | require_torch_accelerator, |
| | torch_device, |
| | ) |
| |
|
| |
|
def download_single_file_checkpoint(repo_id, filename, tmpdir):
    """Fetch one checkpoint file from the Hub into a local directory.

    Args:
        repo_id: Hub repository that holds the checkpoint.
        filename: Name of the checkpoint file inside that repository.
        tmpdir: Directory forwarded to ``hf_hub_download`` as ``local_dir``.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)
| |
|
| |
|
def download_original_config(config_url, tmpdir, timeout=60):
    """Download an original (non-diffusers) config file into ``tmpdir``.

    The response body is written byte-for-byte to ``<tmpdir>/config.yaml``.

    Args:
        config_url: URL of the config file to fetch.
        tmpdir: Directory the file is written into.
        timeout: Seconds forwarded to ``requests.get`` so a stalled connection fails
            instead of hanging the whole test run. Defaults to 60.

    Returns:
        The path of the written ``config.yaml``.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without this
            check an HTML error page would be saved as the config and only fail later
            with a confusing parse error.
    """
    response = requests.get(config_url, timeout=timeout)
    response.raise_for_status()

    path = f"{tmpdir}/config.yaml"
    with open(path, "wb") as f:
        # ``response.content`` is already bytes; no need to wrap it in a BytesIO.
        f.write(response.content)

    return path
| |
|
| |
|
def download_diffusers_config(repo_id, tmpdir):
    """Snapshot only the configuration/text files of a diffusers repo into ``tmpdir``.

    Weight files (``.ckpt``, ``.bin``, ``.pt``, ``.safetensors``) are passed as
    ``ignore_patterns``, both at the repo root and in subfolders. Only ``.json`` and
    ``.txt`` files are passed as ``allow_patterns``.

    Args:
        repo_id: Hub repository id of the diffusers-format model.
        tmpdir: Directory forwarded to ``snapshot_download`` as ``local_dir``.

    Returns:
        The local snapshot path returned by ``snapshot_download``.
    """
    weight_globs = [
        "**/*.ckpt",
        "*.ckpt",
        "**/*.bin",
        "*.bin",
        "**/*.pt",
        "*.pt",
        "**/*.safetensors",
        "*.safetensors",
    ]
    config_globs = ["**/*.json", "*.json", "*.txt", "**/*.txt"]

    snapshot_dir = snapshot_download(
        repo_id,
        ignore_patterns=weight_globs,
        allow_patterns=config_globs,
        local_dir=tmpdir,
    )
    return snapshot_dir
| |
|
| |
|
@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    """Shared checks comparing a model loaded with ``from_single_file`` to ``from_pretrained``.

    Subclasses are expected to define ``model_class``, ``ckpt_path`` and ``repo_id``.
    They may also define ``subfolder`` (pretrained loading only), ``torch_dtype``
    (both loaders) and ``alternate_keys_ckpt_paths``.
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_loading_kwargs(self):
        """Build loader kwargs from the optional class attributes.

        Returns:
            A ``(pretrained_kwargs, single_file_kwargs)`` tuple of fresh dicts.
            ``subfolder`` only applies to ``from_pretrained``; ``torch_dtype`` applies to both.
        """
        pretrained_kwargs = {}
        single_file_kwargs = {}

        subfolder = getattr(self, "subfolder", None)
        if subfolder:
            pretrained_kwargs["subfolder"] = subfolder

        torch_dtype = getattr(self, "torch_dtype", None)
        if torch_dtype:
            pretrained_kwargs["torch_dtype"] = torch_dtype
            single_file_kwargs["torch_dtype"] = torch_dtype

        return pretrained_kwargs, single_file_kwargs

    def test_single_file_model_config(self):
        """Every single-file config entry (bar bookkeeping keys) must equal the pretrained one."""
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        """Both loaders must yield the same state-dict keys, shapes and (within 1e-5) values."""
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading"
        )

        for key, param in state_dict.items():
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            # The max-difference message is only evaluated when the assertion fails.
            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        """Each checkpoint in ``alternate_keys_ckpt_paths`` must load via ``from_single_file``.

        Only loading is checked; weights are not compared. Returns early (passing) when
        the subclass defines no alternate checkpoints.
        """
        alternate_keys_ckpt_paths = getattr(self, "alternate_keys_ckpt_paths", None)
        if not alternate_keys_ckpt_paths:
            return

        # Loop-invariant: the single-file kwargs are the same for every checkpoint.
        _, single_file_kwargs = self._get_loading_kwargs()

        for ckpt_path in alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            # Drop the model before the next iteration to keep peak memory down.
            del model
            gc.collect()
            backend_empty_cache(torch_device)
| |
|
| |
|
class SDSingleFileTesterMixin:
    """Shared checks for Stable Diffusion pipelines loaded via ``from_single_file``.

    Each check compares against the same pipeline loaded with ``from_pretrained``.
    Subclasses are expected to provide ``pipeline_class``, ``ckpt_path``, ``repo_id``,
    ``original_config`` and ``get_inputs``.
    """

    # Extra kwargs for ``from_single_file`` in the inference test. Subclasses may
    # override it; this mixin only reads it (it is a shared class-level dict, so it
    # must never be mutated in place).
    single_file_kwargs = {}

    @staticmethod
    def _compare_text_encoder_configs(pretrained_encoder, single_file_encoder):
        """Assert every non-ignored single-file text-encoder config key matches the pretrained one."""
        pretrained_config = pretrained_encoder.config.to_dict()
        for param_name, param_value in single_file_encoder.config.to_dict().items():
            if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                continue
            assert pretrained_config[param_name] == param_value, (
                f"single file text encoder {param_name}: {param_value} differs from pretrained "
                f"{pretrained_config[param_name]}"
            )

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that ``single_file_pipe`` matches ``pipe`` component by component.

        The text encoder configs are compared first. Every other non-optional component
        (except tokenizer, safety checker and feature extractor) must exist in ``pipe``,
        be an instance of the pretrained component's class, and have equal config values
        apart from bookkeeping keys. ``pipe`` is not modified.
        """
        self._compare_text_encoder_configs(pipe.text_encoder, single_file_pipe.text_encoder)

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # The text encoder was checked above; these are not compared config-by-config.
            if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pipe.components[component_name]
            assert isinstance(component, pretrained_component.__class__), (
                f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"
            )

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                expected_value = pretrained_component.config[param_name]
                # A pretrained config may store upcast_attention as None while single-file
                # loading yields a concrete value; accept the single-file value in that case.
                # This is resolved locally so the caller's pretrained config is not rewritten.
                if param_name == "upcast_attention" and expected_value is None:
                    expected_value = param_value

                assert expected_value == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained {expected_value}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # upcast_attention is copied from the pretrained UNet and passed explicitly so both
        # loads agree; presumably it cannot be inferred from the original config alone — TODO confirm.
        upcast_attention = pipe.unet.config.upcast_attention

        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Passed explicitly for the same reason as in the non-local variant above.
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                safety_checker=None,
                upcast_attention=upcast_attention,
                local_files_only=True,
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        """Images from the single-file and pretrained pipelines must be nearly identical.

        Both pipelines use a plain ``AttnProcessor`` and model CPU offload. The first
        images are compared with ``numpy_cosine_similarity_distance`` on flattened pixels.
        """
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs)
        sf_pipe.unet.set_attn_processor(AttnProcessor())
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        # Inputs are rebuilt for each run; presumably get_inputs returns a seeded
        # generator that is consumed by a pipeline call — TODO confirm.
        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}"

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        """Loading with ``torch_dtype=torch.float16`` must put every ``nn.Module`` component in fp16."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        for component_name, component in single_file_pipe.components.items():
            if not isinstance(component, torch.nn.Module):
                continue

            assert component.dtype == torch.float16, (
                f"{component_name} has dtype {component.dtype}, expected torch.float16"
            )
| |
|
| |
|
class SDXLSingleFileTesterMixin:
    """Shared checks for SDXL pipelines loaded via ``from_single_file``.

    Each check compares against the same pipeline loaded with ``from_pretrained``.
    Subclasses are expected to provide ``pipeline_class``, ``ckpt_path``, ``repo_id``,
    ``original_config`` and ``get_inputs``.
    """

    @staticmethod
    def _compare_text_encoder_configs(pretrained_encoder, single_file_encoder):
        """Assert every non-ignored single-file text-encoder config key matches the pretrained one."""
        pretrained_config = pretrained_encoder.config.to_dict()
        for param_name, param_value in single_file_encoder.config.to_dict().items():
            if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                continue
            assert pretrained_config[param_name] == param_value, (
                f"single file text encoder {param_name}: {param_value} differs from pretrained "
                f"{pretrained_config[param_name]}"
            )

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that ``single_file_pipe`` matches ``pipe`` component by component.

        ``text_encoder`` is compared only when the pretrained pipeline has one (it is
        None for some pipelines); ``text_encoder_2`` is always compared. Every other
        non-optional component (except text encoders, tokenizers, safety checker and
        feature extractor) must exist in ``pipe``, be an instance of the pretrained
        component's class, and have equal config values apart from bookkeeping keys.
        ``pipe`` is not modified.
        """
        if pipe.text_encoder is not None:
            self._compare_text_encoder_configs(pipe.text_encoder, single_file_pipe.text_encoder)

        self._compare_text_encoder_configs(pipe.text_encoder_2, single_file_pipe.text_encoder_2)

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        skipped_components = [
            # Text encoders were compared above; tokenizers are not compared here.
            "text_encoder",
            "text_encoder_2",
            "tokenizer",
            "tokenizer_2",
            # Not compared config-by-config.
            "safety_checker",
            "feature_extractor",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue
            if component_name in skipped_components:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pipe.components[component_name]
            assert isinstance(component, pretrained_component.__class__), (
                f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"
            )

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                expected_value = pretrained_component.config[param_name]
                # A pretrained config may store upcast_attention as None while single-file
                # loading yields a concrete value; accept the single-file value in that case.
                # This is resolved locally so the caller's pretrained config is not rewritten.
                if param_name == "upcast_attention" and expected_value is None:
                    expected_value = param_value

                assert expected_value == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained {expected_value}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # upcast_attention is copied from the pretrained UNet and passed explicitly so both
        # loads agree; presumably it cannot be inferred from the original config alone — TODO confirm.
        upcast_attention = pipe.unet.config.upcast_attention
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Passed explicitly for the same reason as in the non-local variant above.
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                upcast_attention=upcast_attention,
                safety_checker=None,
                local_files_only=True,
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        """fp16 images from the single-file and pretrained pipelines must be nearly identical.

        Both pipelines use the UNet's default attention processor and model CPU offload.
        The first images are compared with ``numpy_cosine_similarity_distance``.
        """
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
        sf_pipe.unet.set_default_attn_processor()
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        # Inputs are rebuilt for each run; presumably get_inputs returns a seeded
        # generator that is consumed by a pipeline call — TODO confirm.
        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        """Loading with ``torch_dtype=torch.float16`` must put every ``nn.Module`` component in fp16."""
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        for component_name, component in single_file_pipe.components.items():
            if not isinstance(component, torch.nn.Module):
                continue

            assert component.dtype == torch.float16, (
                f"{component_name} has dtype {component.dtype}, expected torch.float16"
            )
| |
|