| import os.path |
| import stat |
| from collections import OrderedDict |
|
|
| from modules import shared, sd_models |
| from lib_controlnet.enums import StableDiffusionVersion |
| from modules_forge.shared import controlnet_dir, supported_preprocessors |
|
|
| from typing import Dict, Tuple, List |
|
|
# File extensions that identify ControlNet model files when scanning model
# directories (compared against the result of os.path.splitext).
CN_MODEL_EXTS = [
    ".pt",
    ".pth",
    ".ckpt",
    ".safetensors",
    ".bin",
    ".patch",
]
|
|
|
|
def traverse_all_files(curr_path, model_list, extensions=None):
    """Recursively collect model files found under ``curr_path``.

    Args:
        curr_path: Directory to scan. A path that does not exist or is not a
            directory contributes nothing instead of raising.
        model_list: List that matching entries are appended to; it is mutated
            in place and also returned.
        extensions: Optional collection of accepted file extensions, compared
            case-sensitively against ``os.path.splitext(path)[1]``. Defaults
            to ``CN_MODEL_EXTS``.

    Returns:
        ``model_list``, extended with ``(path, os.stat_result)`` tuples for
        every matching file found at any depth.
    """
    if extensions is None:
        extensions = CN_MODEL_EXTS

    # Check before scanning: os.scandir raises on missing / non-directory paths.
    if not os.path.isdir(curr_path):
        return model_list

    f_list = []
    # Use the iterator as a context manager so the directory handle is closed.
    with os.scandir(curr_path) as entries:
        for entry in entries:
            try:
                fstat = entry.stat()
            except OSError:
                # e.g. a dangling symlink: skip this entry instead of aborting the scan.
                continue
            f_list.append((os.path.join(curr_path, entry.name), fstat))

    for fname, fstat in f_list:
        if stat.S_ISDIR(fstat.st_mode):
            # Directories are checked first so a folder named like "foo.pt" is
            # descended into rather than reported as a model file.
            traverse_all_files(fname, model_list, extensions)
        elif os.path.splitext(fname)[1] in extensions:
            model_list.append((fname, fstat))
    return model_list
|
|
|
|
def get_all_models(sort_by, filter_by, path):
    """Map display names to model file paths for every model under ``path``.

    Args:
        sort_by: "name" (by basename), "date" (newest modification time
            first) or "path name" (by full path); any other value keeps the
            traversal order.
        filter_by: Case-insensitive substring the file's basename must
            contain. Surrounding spaces are ignored; an empty filter keeps
            every file.
        path: Directory to scan recursively via ``traverse_all_files``.

    Returns:
        OrderedDict of ``"<stem> [<hash>]" -> file path``, where the hash comes
        from ``sd_models.model_hash``. Files whose stem is exactly "None" are
        skipped.
    """
    entries = traverse_all_files(path, [])

    needle = filter_by.strip(" ").lower()
    if needle:
        entries = [e for e in entries if needle in os.path.basename(e[0]).lower()]

    sort_keys = {
        "name": lambda e: os.path.basename(e[0]),
        "date": lambda e: -e[1].st_mtime,
        "path name": None,  # natural tuple ordering, i.e. by full path first
    }
    if sort_by in sort_keys:
        entries = sorted(entries, key=sort_keys[sort_by])

    models = OrderedDict()
    for model_path, _ in entries:
        stem = os.path.splitext(os.path.basename(model_path))[0]
        if stem == "None":
            continue
        models[stem + f" [{sd_models.model_hash(model_path)}]"] = model_path
    return models
|
|
|
|
# Module-level registry, rebuilt by update_controlnet_filenames():
# display name -> model file path, seeded with a 'None' placeholder entry.
controlnet_filename_dict = {'None': 'model.safetensors'}
# The registry's keys in insertion order (the 'None' placeholder first).
controlnet_names = [*controlnet_filename_dict]
|
|
|
|
def get_preprocessor(name):
    """Return the registered preprocessor called ``name``, or None if it is not registered."""
    return supported_preprocessors.get(name)
|
|
def get_default_preprocessor(tag):
    """Pick the preprocessor a control type should start out with.

    Returns the only name when ``tag`` selects a single preprocessor, and the
    second name otherwise. The first entry is normally the 'None' placeholder
    (tag-filtered results list it first), so the second is the first real
    preprocessor.
    """
    names = get_filtered_preprocessor_names(tag)
    assert len(names) > 0
    if len(names) == 1:
        return names[0]
    return names[1]
|
|
def get_sorted_preprocessors():
    """Return every registered preprocessor in display order.

    The 'None' placeholder always comes first. The remaining preprocessors
    are ordered by ``sorting_priority`` descending, with ties broken by
    ``name`` descending.

    Returns:
        OrderedDict mapping preprocessor name to preprocessor object.

    Raises:
        KeyError: if no 'None' preprocessor is registered.
    """
    others = [p for k, p in supported_preprocessors.items() if k != 'None']
    # Compare priorities numerically. A zero-padded string key misorders
    # negative or fractional priorities (str(-5).zfill(8) sorts below
    # str(-10).zfill(8)). For non-negative ints the order is unchanged.
    # NOTE(review): assumes sorting_priority is always numeric — confirm.
    others.sort(key=lambda p: (p.sorting_priority, p.name), reverse=True)

    results = OrderedDict()
    results['None'] = supported_preprocessors['None']
    for p in others:
        results[p.name] = p
    return results
|
|
|
|
def get_all_controlnet_names():
    """Return the module-level list of registered ControlNet model names.

    The list object itself is returned, not a copy. It is rebound each time
    ``update_controlnet_filenames`` rescans the model directories.
    """
    return controlnet_names
|
|
|
|
def get_controlnet_filename(controlnet_name):
    """Return the file path registered under the display name ``controlnet_name``.

    Raises:
        KeyError: if the name is not in the registry.
    """
    registry = controlnet_filename_dict
    return registry[controlnet_name]
|
|
|
|
def get_all_preprocessor_names():
    """Return every preprocessor name in the order given by get_sorted_preprocessors()."""
    ordered = get_sorted_preprocessors()
    return list(ordered)
|
|
|
|
def get_all_preprocessor_tags():
    """Return 'All' followed by every distinct preprocessor tag in sorted order."""
    unique_tags = {
        tag
        for preprocessor in supported_preprocessors.values()
        for tag in preprocessor.tags
    }
    return ['All', *sorted(unique_tags)]
|
|
|
|
def get_filtered_preprocessors(tag):
    """Return the preprocessors that belong to control type ``tag``.

    The result is a fresh dict in ``get_sorted_preprocessors()`` order, so the
    'None' placeholder comes first. ``'All'`` selects every registered
    preprocessor. Any other tag keeps the preprocessors whose ``tags`` contain
    it, plus 'None'.
    """
    sorted_preprocessors = get_sorted_preprocessors()
    if tag == 'All':
        # Same ordering as every other tag, and a copy rather than the live
        # registry, which callers could otherwise mutate by accident.
        return sorted_preprocessors
    return {k: v for k, v in sorted_preprocessors.items() if tag in v.tags or k == 'None'}
|
|
|
|
def get_filtered_preprocessor_names(tag):
    """Return the names of the preprocessors selected by ``tag``, preserving their order."""
    filtered = get_filtered_preprocessors(tag)
    return list(filtered)
|
|
|
|
def get_filtered_controlnet_names(tag):
    """Return the registered model names usable with control type ``tag``.

    The 'None' placeholder is always kept. Any other model is kept when both
    of these hold:

    - its name contains, case-insensitively, one of the
      ``model_filename_filters`` of the preprocessors selected by ``tag``;
    - ``get_sd_version().is_compatible_with(...)`` accepts the version that
      ``StableDiffusionVersion.detect_from_model_name`` reports for it.

    Registry order is preserved.
    """
    filters = []
    for preprocessor in get_filtered_preprocessors(tag).values():
        filters.extend(preprocessor.model_filename_filters)

    selected = []
    for model_name in controlnet_names:
        if model_name == 'None':
            selected.append(model_name)
            continue
        lowered = model_name.lower()
        if not any(f.lower() in lowered for f in filters):
            continue
        # The current version is looked up only for models whose name matched.
        current_version = get_sd_version()
        if current_version.is_compatible_with(StableDiffusionVersion.detect_from_model_name(model_name)):
            selected.append(model_name)
    return selected
|
|
|
|
def update_controlnet_filenames():
    """Rescan the ControlNet model directories and rebuild the model registry.

    Scans ``controlnet_dir`` plus two optional extra directories: the
    ``control_net_models_path`` option and the ``controlnet_dir`` attribute of
    ``shared.cmd_opts``. Unset or non-existent directories are skipped.

    The module-level ``controlnet_filename_dict`` and ``controlnet_names``
    are then rebound together, and the 'None' placeholder is always the first
    entry. If scanning raises, the previous registry is left in place.
    """
    global controlnet_filename_dict, controlnet_names

    # Sorting / filtering options are the same for every directory; read once.
    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")

    extra_dirs = (
        shared.opts.data.get("control_net_models_path", None),
        getattr(shared.cmd_opts, 'controlnet_dir', None),
    )
    # dict.fromkeys keeps order while dropping repeated directory strings, so
    # the same folder is not scanned and hashed twice (its entries would be
    # identical anyway).
    paths = [
        p for p in dict.fromkeys((controlnet_dir, *extra_dirs))
        if p is not None and os.path.exists(p)
    ]

    # Build locally, then swap both globals at once so readers never see a
    # half-populated registry.
    filename_dict = {'None': 'model.safetensors'}
    for path in paths:
        filename_dict.update(get_all_models(sort_by, filter_by, path))

    controlnet_filename_dict = filename_dict
    controlnet_names = list(filename_dict.keys())
|
|
|
|
def get_sd_version() -> StableDiffusionVersion:
    """Classify the current ``shared.sd_model`` by Stable Diffusion family.

    The model's ``is_sdxl``, ``is_sd2`` and ``is_sd1`` flags are checked in
    that order, and the first true flag decides the result. Returns
    ``StableDiffusionVersion.UNKNOWN`` in two cases: ``shared.sd_model`` is
    falsy (presumably no model loaded), or none of the flags is set.
    """
    model = shared.sd_model
    if not model:
        return StableDiffusionVersion.UNKNOWN
    checks = (
        ("is_sdxl", StableDiffusionVersion.SDXL),
        ("is_sd2", StableDiffusionVersion.SD2x),
        ("is_sd1", StableDiffusionVersion.SD1x),
    )
    for flag_name, version in checks:
        if getattr(model, flag_name):
            return version
    return StableDiffusionVersion.UNKNOWN
|
|
|
|
def select_control_type(
    control_type: str,
    sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
) -> Tuple[List[str], List[str], str, str]:
    """Compute the preprocessor and model choices for a control type.

    Args:
        control_type: Control type tag. "All" (in any letter case) selects
            every preprocessor and model.
        sd_version: Accepted for API compatibility but not used here. Model
            compatibility is checked against ``get_sd_version()`` inside
            ``get_filtered_controlnet_names``.

    Returns:
        Tuple ``(preprocessor_names, model_names, default_preprocessor,
        default_model)``. Both defaults are always members of the
        corresponding name list.
    """
    pattern = control_type.lower()

    if pattern == "all":
        preprocessor_names = [p.name for p in get_sorted_preprocessors().values()]
        return (
            preprocessor_names,
            list(controlnet_names),  # copy so callers cannot mutate the registry
            # First entry is the 'None' placeholder's own name, so the default
            # is guaranteed to be one of the returned choices.
            preprocessor_names[0],
            "None",
        )

    filtered_model_list = get_filtered_controlnet_names(control_type)

    # 'None' must be selectable. get_filtered_controlnet_names normally keeps
    # it already, so only add it when missing to avoid a duplicate entry.
    if pattern == "none" and "None" not in filtered_model_list:
        filtered_model_list.append("None")

    assert len(filtered_model_list) > 0, "'None' model should always be available."
    if len(filtered_model_list) == 1:
        default_model = "None"
    else:
        # Index 0 is normally 'None'. Prefer a model with "11" in its name
        # (text before the " [hash]" suffix), presumably ControlNet v1.1
        # naming; otherwise take the first real model.
        default_model = next(
            (x for x in filtered_model_list if "11" in x.split("[")[0]),
            filtered_model_list[1],
        )

    return (
        get_filtered_preprocessor_names(control_type),
        filtered_model_list,
        get_default_preprocessor(control_type),
        default_model,
    )
|
|