Spaces:
Runtime error
Runtime error
| import os.path | |
| import stat | |
| from collections import OrderedDict | |
| from modules import shared, scripts, sd_models | |
| from modules.paths import models_path | |
| from scripts.enums import StableDiffusionVersion | |
| from scripts.supported_preprocessor import Preprocessor | |
| from typing import Dict, Tuple, List | |
# File extensions that are treated as ControlNet model weights while scanning.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]

# Model search locations: the "ControlNet" folder under the webui models
# path, and a legacy "models" folder inside this extension's directory.
cn_models_dir = os.path.join(models_path, "ControlNet")
cn_models_dir_old = os.path.join(scripts.basedir(), "models")

# Registry of discovered models: display key "<stem> [<hash>]" -> file path.
# Kept as a single long-lived object; it is only ever mutated in place.
cn_models = OrderedDict()
# Lower-cased file stem -> display key in ``cn_models``.
cn_models_names = {}

default_detectedmap_dir = "detected_maps"
script_dir = scripts.basedir()

# Ensure the primary model directory exists as soon as the module is imported.
os.makedirs(cn_models_dir, exist_ok=True)
def traverse_all_files(curr_path, model_list):
    """Recursively collect model files located under ``curr_path``.

    Every entry whose extension appears in ``CN_MODEL_EXTS`` is appended to
    ``model_list`` as a ``(full_path, os.stat_result)`` tuple, and every
    sub-directory is descended into.

    Args:
        curr_path: Directory to scan.
        model_list: List that results are appended to (mutated in place).

    Returns:
        ``model_list``. When ``curr_path`` is missing or is not a directory,
        it is returned unchanged instead of raising.
    """
    # Check before scanning: os.scandir() raises on a missing path or on a
    # regular file, so a guard evaluated inside the loop comes too late.
    if not os.path.isdir(curr_path):
        return model_list

    f_list = []
    # Use scandir as a context manager so the directory handle is released
    # deterministically rather than whenever the iterator is collected.
    with os.scandir(curr_path) as entries:
        for entry in entries:
            try:
                # stat() follows symlinks; a dangling link raises OSError.
                # Skip such entries instead of aborting the whole scan.
                f_list.append((os.path.join(curr_path, entry.name), entry.stat()))
            except OSError:
                continue

    for f_info in f_list:
        fname, fstat = f_info
        if os.path.splitext(fname)[1] in CN_MODEL_EXTS:
            model_list.append(f_info)
        elif stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
    return model_list
def get_all_models(sort_by, filter_by, path):
    """Return the model files under ``path`` as an ordered display-name map.

    Args:
        sort_by: ``"name"`` (file basename), ``"date"`` (newest mtime first)
            or ``"path name"`` (full path). Any other value keeps scan order.
        filter_by: Case-insensitive substring that the file basename must
            contain. Surrounding spaces are ignored; an empty value disables
            filtering.
        path: Root directory scanned recursively via ``traverse_all_files``.

    Returns:
        ``OrderedDict`` mapping ``"<stem> [<hash>]"`` (hash from
        ``sd_models.model_hash``) to the file path, in the requested order.
        Files whose stem is exactly ``"None"`` are left out.
    """
    entries = traverse_all_files(path, [])

    needle = filter_by.strip(" ")
    if needle:
        needle = needle.lower()
        entries = [e for e in entries if needle in os.path.basename(e[0]).lower()]

    # (option, sort key); a key of None sorts the (path, stat) tuples as-is,
    # which orders them by full path.
    for option, key in (
        ("name", lambda e: os.path.basename(e[0])),
        ("date", lambda e: -e[1].st_mtime),
        ("path name", None),
    ):
        if sort_by == option:
            entries = sorted(entries, key=key)
            break

    listing = OrderedDict()
    for file_path, _file_stat in entries:
        stem = os.path.splitext(os.path.basename(file_path))[0]
        # A file literally named "None.<ext>" is never listed.
        if stem == "None":
            continue
        listing[stem + f" [{sd_models.model_hash(file_path)}]"] = file_path
    return listing
def update_cn_models():
    """Rescan the model directories and rebuild the module-level registries.

    Directories are scanned in priority order:
    ``cn_models_dir``, ``cn_models_dir_old``, the ``control_net_models_path``
    option and ``shared.cmd_opts.controlnet_dir``. When two directories
    produce the same display key, the entry from the earlier directory is
    kept. Each directory's listing is filtered and sorted independently
    according to the ``control_net_models_name_filter`` and
    ``control_net_models_sort_models_by`` options.

    Unset options, missing paths and paths that are not directories are
    skipped. A directory listed more than once is scanned only once.

    ``cn_models`` and ``cn_models_names`` are mutated in place, never
    rebound, because other code holds references to these exact objects
    (e.g. the default argument of ``select_control_type``).
    """
    cn_models.clear()

    # These options are identical for every directory, so read them once.
    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")

    user_dirs = (
        shared.opts.data.get("control_net_models_path", None),
        getattr(shared.cmd_opts, 'controlnet_dir', None),
    )
    search_dirs = [cn_models_dir, cn_models_dir_old, *user_dirs]

    scanned = set()
    for path in search_dirs:
        # Skip unset options (None / "") and anything that is not a
        # directory: scanning such a path would raise.
        if not path or not os.path.isdir(path):
            continue
        # The same directory may be configured twice. Scanning it again
        # would only produce keys that are already present.
        real = os.path.normcase(os.path.realpath(path))
        if real in scanned:
            continue
        scanned.add(real)

        found = get_all_models(sort_by, filter_by, path)
        for name_and_hash, filename in found.items():
            # Earlier directories win on key collisions.
            cn_models.setdefault(name_and_hash, filename)

    # Put the "None" placeholder at the front, in place.
    cn_models.setdefault("None", None)
    cn_models.move_to_end("None", last=False)

    # Lower-cased file stem -> display key, for name-based lookups.
    cn_models_names.clear()
    cn_models_names.update(
        (os.path.splitext(os.path.basename(filename))[0].lower(), name_and_hash)
        for name_and_hash, filename in cn_models.items()
        if filename is not None
    )
def get_sd_version() -> StableDiffusionVersion:
    """Detect the model family of the currently loaded checkpoint.

    If the model exposes ``is_sdxl``, the ``is_sdxl`` / ``is_sd2`` /
    ``is_sd1`` flags decide. Otherwise, as a fallback for webui < 1.5.0, the
    family is inferred from the model's attributes.

    Returns:
        The matching ``StableDiffusionVersion``. Returns ``UNKNOWN`` when no
        model is loaded, or when the flags identify none of the known
        families.
    """
    sd_model = shared.sd_model
    if sd_model is None:
        # No checkpoint available: the attribute probes below would raise
        # AttributeError on None.
        return StableDiffusionVersion.UNKNOWN

    if hasattr(sd_model, 'is_sdxl'):
        if sd_model.is_sdxl:
            return StableDiffusionVersion.SDXL
        if sd_model.is_sd2:
            return StableDiffusionVersion.SD2x
        if sd_model.is_sd1:
            return StableDiffusionVersion.SD1x
        return StableDiffusionVersion.UNKNOWN

    # Backward compatibility for webui < 1.5.0, which lacks the is_* flags:
    # fall back to structural probes on the model object.
    if hasattr(sd_model, 'conditioner'):
        return StableDiffusionVersion.SDXL
    if hasattr(sd_model.cond_stage_model, 'model'):
        return StableDiffusionVersion.SD2x
    return StableDiffusionVersion.SD1x
def select_control_type(
    control_type: str,
    sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
    cn_models: Dict = cn_models, # Override or testing
) -> Tuple[List[str], List[str], str, str]:
    """Compute the preprocessor/model choices for a ControlNet control type.

    Args:
        control_type: Control type tag, matched case-insensitively.
            ``"all"`` returns every preprocessor and every model.
        sd_version: Checkpoint family. A model is dropped when
            ``sd_version.is_compatible_with`` rejects the version detected
            from its name.
        cn_models: Model registry to choose from. Defaults to the
            module-level registry; pass another mapping for testing.

    Returns:
        ``(preprocessor labels, model keys, default preprocessor label,
        default model key)``. The ``"all"`` branch returns these four values
        as a list rather than a tuple.

        The "None" model key is always kept. Any other key is kept when its
        lower-cased text contains the pattern or one of the tags from
        ``Preprocessor.tag_to_filters``, and its version is compatible.

        The default model is the first kept key whose name part (before
        ``"["``) contains ``"11"``. Failing that, it is the second kept
        entry, which presumably follows "None" when "None" is listed first.
        It is ``"None"`` when nothing besides a single entry was kept.
    """
    pattern = control_type.lower()
    model_names = list(cn_models.keys())

    if pattern == "all":
        every_preprocessor = [p.label for p in Preprocessor.get_sorted_preprocessors()]
        return [every_preprocessor, model_names, 'none', "None"]

    def _keep(model: str):
        lowered = model.lower()
        if lowered == "none":
            return True
        matches_type = pattern in lowered or any(
            tag in lowered for tag in Preprocessor.tag_to_filters(control_type)
        )
        if not matches_type:
            return False
        return sd_version.is_compatible_with(
            StableDiffusionVersion.detect_from_model_name(model)
        )

    filtered_model_list = [m for m in model_names if _keep(m)]
    assert filtered_model_list, "'None' model should always be available."

    if len(filtered_model_list) == 1:
        default_model = "None"
    else:
        # Prefer a v1.1 model ("11" in the name part); else the 2nd entry.
        default_model = next(
            (m for m in filtered_model_list if "11" in m.split("[")[0]),
            filtered_model_list[1],
        )

    preprocessor_labels = [
        p.label for p in Preprocessor.get_filtered_preprocessors(control_type)
    ]
    default_preprocessor = Preprocessor.get_default_preprocessor(control_type).label
    return (
        preprocessor_labels,
        filtered_model_list,
        default_preprocessor,
        default_model,
    )