| from pathlib import Path |
|
|
| import pytest |
|
|
| from invokeai.backend.model_manager.search import ModelSearch |
|
|
|
|
@pytest.fixture
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
    """Provide a newly constructed ``ModelSearch`` together with pytest's ``tmp_path``.

    Tests unpack the pair as ``(searcher, root)`` and populate ``root`` with
    whatever files/directories they want the searcher to walk.
    """
    return ModelSearch(), tmp_path
|
|
|
|
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
    """``on_search_started`` receives the directory that was passed to ``search()``."""
    searcher, root = model_search
    # Holds the most recent argument the callback saw; stays empty if it is never invoked.
    captured: dict[str, Path] = {}

    def record_start(path: Path) -> None:
        captured["path"] = path

    searcher.on_search_started = record_start
    searcher.search(root)

    assert captured.get("path") == root
|
|
|
|
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
    """``search()`` returns the found models and passes the same set to ``on_search_completed``."""
    searcher, root = model_search
    checkpoint = root / "file1.ckpt"
    checkpoint.write_text("")  # an empty file is enough for the search to pick it up

    # Holds the set handed to the completion callback; stays empty if it is never invoked.
    captured: dict[str, set[Path]] = {}

    def record_completion(models: set[Path]) -> None:
        captured["models"] = models

    searcher.on_search_completed = record_completion
    wanted = {checkpoint}
    result = searcher.search(root)

    assert result == wanted
    assert captured.get("models") == wanted
|
|
|
|
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
    """``.ckpt`` files are found at any nesting depth, while the ``.txt`` file is not reported."""
    searcher, root = model_search
    reported: set[Path] = set()

    nested = root / "subfolder"
    checkpoints = (
        root / "file1.ckpt",
        root / "file2.ckpt",
        nested / "file3.ckpt",
        nested / "subfolder" / "file4.ckpt",
    )
    non_model = root / "not_a_model_file.txt"

    # Creating the deepest directory also creates its parent, "subfolder".
    (nested / "subfolder").mkdir(parents=True)
    for path in (*checkpoints, non_model):
        path.write_text("")

    def accept_everything(path: Path) -> bool:
        reported.add(path)
        return True

    searcher.on_model_found = accept_everything

    wanted = set(checkpoints)
    result = searcher.search(root)

    assert reported == wanted
    assert result == wanted
    assert searcher.stats.models_found == 4
    assert searcher.stats.models_filtered == 4
|
|
|
|
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
    """A model for which ``on_model_found`` returns False is left out of the results.

    Both files count toward ``models_found``, but only the accepted one counts
    toward ``models_filtered``.
    """
    searcher, root = model_search
    accepted: set[Path] = set()

    keep = root / "file1.ckpt"
    reject = root / "file2.ckpt"
    for path in (keep, reject):
        path.write_text("")

    def accept_all_but_reject(path: Path) -> bool:
        is_wanted = path != reject
        if is_wanted:
            accepted.add(path)
        return is_wanted

    searcher.on_model_found = accept_all_but_reject

    wanted = {keep}
    result = searcher.search(root)

    assert accepted == wanted
    assert result == wanted
    assert searcher.stats.models_filtered == 1
    assert searcher.stats.models_found == 2
|
|
|
|
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
    """Directories containing ``model_index.json`` are reported as models.

    Checks three cases:
    - ``ignore_me.ckpt``, which sits inside such a directory, is not reported separately.
    - A directory that holds only ``not_model_index.json`` is not reported.
    - A top-level and a nested diffusers directory are both found.
    """
    searcher, root = model_search
    reported: set[Path] = set()

    top_level_model = root / "diffusers_dir"
    top_level_model.mkdir()
    (top_level_model / "model_index.json").write_text("")

    nested_model = root / "subfolder" / "nested_diffusers_dir"
    nested_model.mkdir(parents=True)
    (nested_model / "model_index.json").write_text("")
    # Lives inside a diffusers directory, so the expected results below do not include it.
    (nested_model / "ignore_me.ckpt").write_text("")

    decoy = root / "not_a_diffusers_dir"
    decoy.mkdir()
    (decoy / "not_model_index.json").write_text("")

    def accept_everything(path: Path) -> bool:
        reported.add(path)
        return True

    searcher.on_model_found = accept_everything

    wanted = {top_level_model, nested_model}
    result = searcher.search(root)

    assert result == wanted
    assert reported == wanted
    assert searcher.stats.models_found == 2
    assert searcher.stats.models_filtered == 2
|
|