from pathlib import Path
from typing import List, Optional

import pytest

from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files
|
|
|
|
| |
@pytest.fixture
def sdxl_base_files() -> List[Path]:
    """Return the SDXL base 1.0 repo file listing as ``Path`` objects.

    The listing mixes several weight formats for each submodel: plain and fp16
    safetensors, ONNX, OpenVINO and Flax. It also has top-level extras such as
    images, README/LICENSE and single-file checkpoints. That mix lets the tests
    check which subset ``filter_files`` keeps for each repo variant.
    """
    repo_listing = (
        ".gitattributes",
        "01.png",
        "LICENSE.md",
        "README.md",
        "comparison.png",
        "model_index.json",
        "pipeline.png",
        "scheduler/scheduler_config.json",
        "sd_xl_base_1.0.safetensors",
        "sd_xl_base_1.0_0.9vae.safetensors",
        "sd_xl_offset_example-lora_1.0.safetensors",
        "text_encoder/config.json",
        "text_encoder/flax_model.msgpack",
        "text_encoder/model.fp16.safetensors",
        "text_encoder/model.onnx",
        "text_encoder/model.safetensors",
        "text_encoder/openvino_model.bin",
        "text_encoder/openvino_model.xml",
        "text_encoder_2/config.json",
        "text_encoder_2/flax_model.msgpack",
        "text_encoder_2/model.fp16.safetensors",
        "text_encoder_2/model.onnx",
        "text_encoder_2/model.onnx_data",
        "text_encoder_2/model.safetensors",
        "text_encoder_2/openvino_model.bin",
        "text_encoder_2/openvino_model.xml",
        "tokenizer/merges.txt",
        "tokenizer/special_tokens_map.json",
        "tokenizer/tokenizer_config.json",
        "tokenizer/vocab.json",
        "tokenizer_2/merges.txt",
        "tokenizer_2/special_tokens_map.json",
        "tokenizer_2/tokenizer_config.json",
        "tokenizer_2/vocab.json",
        "unet/config.json",
        "unet/diffusion_flax_model.msgpack",
        "unet/diffusion_pytorch_model.fp16.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "unet/model.onnx",
        "unet/model.onnx_data",
        "unet/openvino_model.bin",
        "unet/openvino_model.xml",
        "vae/config.json",
        "vae/diffusion_flax_model.msgpack",
        "vae/diffusion_pytorch_model.fp16.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
        "vae_1_0/config.json",
        "vae_1_0/diffusion_pytorch_model.fp16.safetensors",
        "vae_1_0/diffusion_pytorch_model.safetensors",
        "vae_decoder/config.json",
        "vae_decoder/model.onnx",
        "vae_decoder/openvino_model.bin",
        "vae_decoder/openvino_model.xml",
        "vae_encoder/config.json",
        "vae_encoder/model.onnx",
        "vae_encoder/openvino_model.bin",
        "vae_encoder/openvino_model.xml",
    )
    return list(map(Path, repo_listing))
|
|
|
|
| |
@pytest.mark.parametrize(
    "variant,expected_list",
    [
        (
            None,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.safetensors",
            ],
        ),
        (
            ModelRepoVariant.Default,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.safetensors",
            ],
        ),
        (
            ModelRepoVariant.OpenVINO,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/openvino_model.bin",
                "text_encoder/openvino_model.xml",
                "text_encoder_2/config.json",
                "text_encoder_2/openvino_model.bin",
                "text_encoder_2/openvino_model.xml",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/openvino_model.bin",
                "unet/openvino_model.xml",
                "vae_decoder/config.json",
                "vae_decoder/openvino_model.bin",
                "vae_decoder/openvino_model.xml",
                "vae_encoder/config.json",
                "vae_encoder/openvino_model.bin",
                "vae_encoder/openvino_model.xml",
            ],
        ),
        (
            ModelRepoVariant.FP16,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.fp16.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.fp16.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.fp16.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.fp16.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.fp16.safetensors",
            ],
        ),
        (
            ModelRepoVariant.ONNX,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.onnx",
                "text_encoder_2/config.json",
                "text_encoder_2/model.onnx",
                "text_encoder_2/model.onnx_data",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/model.onnx",
                "unet/model.onnx_data",
                "vae_decoder/config.json",
                "vae_decoder/model.onnx",
                "vae_encoder/config.json",
                "vae_encoder/model.onnx",
            ],
        ),
        (
            ModelRepoVariant.Flax,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/flax_model.msgpack",
                "text_encoder_2/config.json",
                "text_encoder_2/flax_model.msgpack",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_flax_model.msgpack",
                "vae/config.json",
                "vae/diffusion_flax_model.msgpack",
            ],
        ),
    ],
)
def test_select(sdxl_base_files: List[Path], variant: Optional[ModelRepoVariant], expected_list: List[str]) -> None:
    """Check which SDXL base files ``filter_files`` keeps for each repo variant.

    ``variant`` may be ``None``. Its expected selection is identical to the one
    for ``ModelRepoVariant.Default``. No expected selection includes the
    top-level extras (images, README/LICENSE, single-file checkpoints).
    Comparison uses sets, so the order of the returned files does not matter.
    """
    filtered_files = filter_files(sdxl_base_files, variant)
    assert set(filtered_files) == {Path(x) for x in expected_list}
|
|
|
|
@pytest.fixture
def sd15_test_files() -> list[Path]:
    """Return an SD 1.5-style repo listing with several copies of each weight.

    Most weighted folders carry ``.bin`` and ``.safetensors`` files in both
    full-precision and ``fp16`` flavours. ``unet`` also has ``non_ema``
    copies, so filtering must pick exactly one file per folder.
    """
    repo_listing = (
        "feature_extractor/preprocessor_config.json",
        "safety_checker/config.json",
        "safety_checker/model.fp16.safetensors",
        "safety_checker/model.safetensors",
        "safety_checker/pytorch_model.bin",
        "safety_checker/pytorch_model.fp16.bin",
        "scheduler/scheduler_config.json",
        "text_encoder/config.json",
        "text_encoder/model.fp16.safetensors",
        "text_encoder/model.safetensors",
        "text_encoder/pytorch_model.bin",
        "text_encoder/pytorch_model.fp16.bin",
        "tokenizer/merges.txt",
        "tokenizer/special_tokens_map.json",
        "tokenizer/tokenizer_config.json",
        "tokenizer/vocab.json",
        "unet/config.json",
        "unet/diffusion_pytorch_model.bin",
        "unet/diffusion_pytorch_model.fp16.bin",
        "unet/diffusion_pytorch_model.fp16.safetensors",
        "unet/diffusion_pytorch_model.non_ema.bin",
        "unet/diffusion_pytorch_model.non_ema.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "vae/config.json",
        "vae/diffusion_pytorch_model.bin",
        "vae/diffusion_pytorch_model.fp16.bin",
        "vae/diffusion_pytorch_model.fp16.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
    )
    return list(map(Path, repo_listing))
|
|
|
|
@pytest.mark.parametrize(
    "variant,expected_files",
    [
        (
            ModelRepoVariant.FP16,
            [
                "feature_extractor/preprocessor_config.json",
                "safety_checker/config.json",
                "safety_checker/model.fp16.safetensors",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.fp16.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.fp16.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.fp16.safetensors",
            ],
        ),
        (
            ModelRepoVariant.FP32,
            [
                "feature_extractor/preprocessor_config.json",
                "safety_checker/config.json",
                "safety_checker/model.safetensors",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.safetensors",
            ],
        ),
    ],
)
def test_select_multiple_weights(
    sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
) -> None:
    """Check that one weight file per submodel survives filtering.

    The expected selections contain a single ``.safetensors`` file in each
    weighted folder. For ``FP16`` that is the ``fp16`` copy; for ``FP32`` it
    is the plain one. The ``.bin`` and ``non_ema`` copies are always dropped.
    """
    expected = set(map(Path, expected_files))
    assert set(filter_files(sd15_test_files, variant)) == expected
|
|
|
|
@pytest.fixture
def flux_schnell_test_files() -> list[Path]:
    """Return a FLUX.1-schnell repo listing nested under a ``FLUX.1-schnell/`` folder.

    ``text_encoder_2`` and ``transformer`` are split into numbered
    ``.safetensors`` shards, each set accompanied by a ``.safetensors.index.json``
    file. The root also holds single-file checkpoints and a preview image.
    """
    repo_root = Path("FLUX.1-schnell")
    relative_names = (
        ".gitattributes",
        "README.md",
        "ae.safetensors",
        "flux1-schnell.safetensors",
        "model_index.json",
        "scheduler/scheduler_config.json",
        "schnell_grid.jpeg",
        "text_encoder/config.json",
        "text_encoder/model.safetensors",
        "text_encoder_2/config.json",
        "text_encoder_2/model-00001-of-00002.safetensors",
        "text_encoder_2/model-00002-of-00002.safetensors",
        "text_encoder_2/model.safetensors.index.json",
        "tokenizer/merges.txt",
        "tokenizer/special_tokens_map.json",
        "tokenizer/tokenizer_config.json",
        "tokenizer/vocab.json",
        "tokenizer_2/special_tokens_map.json",
        "tokenizer_2/spiece.model",
        "tokenizer_2/tokenizer.json",
        "tokenizer_2/tokenizer_config.json",
        "transformer/config.json",
        "transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
        "transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
        "transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
        "transformer/diffusion_pytorch_model.safetensors.index.json",
        "vae/config.json",
        "vae/diffusion_pytorch_model.safetensors",
    )
    # Joining with "/" yields Paths equal to Path("FLUX.1-schnell/<name>").
    return [repo_root / name for name in relative_names]
|
|
|
|
@pytest.mark.parametrize(
    ["variant", "expected_files"],
    [
        (
            ModelRepoVariant.Default,
            [
                "FLUX.1-schnell/model_index.json",
                "FLUX.1-schnell/scheduler/scheduler_config.json",
                "FLUX.1-schnell/text_encoder/config.json",
                "FLUX.1-schnell/text_encoder/model.safetensors",
                "FLUX.1-schnell/text_encoder_2/config.json",
                "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
                "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
                "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
                "FLUX.1-schnell/tokenizer/merges.txt",
                "FLUX.1-schnell/tokenizer/special_tokens_map.json",
                "FLUX.1-schnell/tokenizer/tokenizer_config.json",
                "FLUX.1-schnell/tokenizer/vocab.json",
                "FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
                "FLUX.1-schnell/tokenizer_2/spiece.model",
                "FLUX.1-schnell/tokenizer_2/tokenizer.json",
                "FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
                "FLUX.1-schnell/transformer/config.json",
                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
                "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
                "FLUX.1-schnell/vae/config.json",
                "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
            ],
        ),
    ],
)
def test_select_flux_schnell_files(
    flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
) -> None:
    """Check filtering of a FLUX.1-schnell listing that lives in a subfolder.

    Every shard and its ``.safetensors.index.json`` file must be kept. The
    extras at the repo root are dropped: ``.gitattributes``, README, the
    ``ae``/``flux1-schnell`` single-file checkpoints and the preview image.
    """
    wanted = {Path(name) for name in expected_files}
    selected = filter_files(flux_schnell_test_files, variant)
    assert set(selected) == wanted
|
|