| | import os |
| | import os.path |
| | import re |
| | import shutil |
| | import json |
| | import stat |
| | import tqdm |
| | from collections import OrderedDict |
| | from multiprocessing.pool import ThreadPool as Pool |
| |
|
| | from modules import shared, sd_models, hashes |
| | from scripts import safetensors_hack, model_util, util |
| | import modules.scripts as scripts |
| |
|
| |
|
| | |
# Number of additional-network slots; taken from the `addnet_max_model_count`
# command-line option when the host defines it, otherwise 5.
MAX_MODEL_COUNT = getattr(shared.cmd_opts, "addnet_max_model_count", 5)
# File extensions (case-sensitive, with leading dot) treated as model files.
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
# Matches a trailing "(xxxxxxxx)" 8-hex-digit legacy hash.  A raw string is
# required: "\(" in a plain literal is an invalid escape sequence.
re_legacy_hash = re.compile(r"\(([0-9a-f]{8})\)$")
# Display name "stem(hash)" -> model file path ("None" -> None).
lora_models = {}
# Lower-cased file stem -> display name.
lora_model_names = {}
# Legacy hash -> display name.
legacy_model_names = {}
# Default model directory; created at import time if it does not exist.
lora_models_dir = os.path.join(scripts.basedir(), "models/lora")
os.makedirs(lora_models_dir, exist_ok=True)
| |
|
| |
|
def is_safetensors(filename):
    """Return True when *filename* carries the exact ``.safetensors`` extension.

    The comparison is case-sensitive and looks only at the final extension.
    """
    _, extension = os.path.splitext(filename)
    return extension == ".safetensors"
| |
|
| |
|
def read_model_metadata(model_path, module):
    """Read the metadata stored in a LoRA ``.safetensors`` file.

    Args:
        model_path: Path to the model file; one pair of surrounding double
            quotes is stripped before use.
        module: Network module name; only "LoRA" is handled.

    Returns:
        Whatever ``safetensors_hack.read_metadata`` returns for the file, or
        None when the path does not exist, the module is not "LoRA", or the
        file is not a ``.safetensors`` file.
    """
    wrapped_in_quotes = model_path.startswith("\"") and model_path.endswith("\"")
    if wrapped_in_quotes:
        model_path = model_path[1:-1]

    if not os.path.exists(model_path):
        return None

    extension = os.path.splitext(model_path)[1]
    if module != "LoRA" or extension != '.safetensors':
        return None

    return safetensors_hack.read_metadata(model_path)
| |
|
| |
|
def write_model_metadata(model_path, module, updates):
    """Merge *updates* into the metadata of a LoRA ``.safetensors`` file.

    The file is rewritten in place with its tensors unchanged.  When the
    "additional_networks_back_up_model_when_saving" option is enabled (the
    default) and no ``<model_path>.backup`` exists yet, the current file is
    copied there first.

    Args:
        model_path: Path to the model file; one pair of surrounding double
            quotes is stripped before use.
        module: Network module name; only "LoRA" is supported.
        updates: Mapping of metadata keys to values; every value is stored
            as ``str(value)``.

    Returns:
        None.  Nothing is written when the path does not exist or the
        target is not a LoRA ``.safetensors`` file.
    """
    if model_path.startswith("\"") and model_path.endswith("\""):
        model_path = model_path[1:-1]
    if not os.path.exists(model_path):
        return None

    # Only LoRA .safetensors files can be loaded below.  Bail out before
    # touching anything: continuing with no tensors loaded would either fail
    # on the metadata update or overwrite the model with an empty file.
    if module != "LoRA" or os.path.splitext(model_path)[1] != '.safetensors':
        return None

    from safetensors.torch import save_file

    back_up = shared.opts.data.get("additional_networks_back_up_model_when_saving", True)
    if back_up:
        backup_path = model_path + ".backup"
        # Never overwrite an existing backup so the oldest copy is kept.
        if not os.path.exists(backup_path):
            print(f"[MetadataEditor] Backing up current model to {backup_path}")
            shutil.copyfile(model_path, backup_path)

    tensors, metadata = safetensors_hack.load_file(model_path, "cpu")
    if metadata is None:
        # Defensive: start from an empty mapping if the loader reports none.
        metadata = {}

    for k, v in updates.items():
        metadata[k] = str(v)

    save_file(tensors, model_path, metadata)
    print(f"[MetadataEditor] Model saved: {model_path}")
| |
|
| |
|
def get_model_list(module, model, model_dir, sort_by):
    """List the display names of the models found under a directory.

    Args:
        module: Unused by this function.
        model: Display name of a registered model; only consulted when
            *model_dir* is empty, to derive the directory from its path.
        model_dir: Directory to scan, or "" to use the directory containing
            *model*.
        sort_by: Sort mode forwarded to ``get_all_models``.

    Returns:
        The keys of the name -> path mapping produced by ``get_all_models``,
        or an empty list when no usable directory can be determined.
    """
    if model_dir == "":
        # Fall back to the folder that holds the currently selected model.
        known_path = lora_models.get(model, None)
        if known_path is None:
            return []
        model_dir = os.path.dirname(known_path)

    if os.path.isdir(model_dir):
        models, _legacy = get_all_models([model_dir], sort_by, "")
        return models.keys()

    return []
| |
|
| |
|
def traverse_all_files(curr_path, model_list):
    """Recursively collect model files below *curr_path*.

    Args:
        curr_path: Directory to scan.
        model_list: List that receives one ``(path, os.stat_result)`` tuple
            per file whose extension is in ``LORA_MODEL_EXTS``.  It is
            extended in place.

    Returns:
        The same *model_list* object.
    """
    # Snapshot the directory inside the context manager so the scandir
    # handle is closed before recursing into subdirectories.
    children = []
    with os.scandir(curr_path) as entries:
        for entry in entries:
            try:
                entry_stat = entry.stat()
            except OSError:
                # E.g. a dangling symlink or an entry removed mid-scan: skip
                # it instead of aborting the whole listing.
                continue
            children.append((os.path.join(curr_path, entry.name), entry_stat))

    for f_info in children:
        fname, fstat = f_info
        # Check for directories first so a folder whose name happens to end
        # in a model extension is descended into rather than listed as a file.
        if stat.S_ISDIR(fstat.st_mode):
            traverse_all_files(fname, model_list)
        elif os.path.splitext(fname)[1] in LORA_MODEL_EXTS:
            model_list.append(f_info)
    return model_list
| |
|
| |
|
def get_model_hash(metadata, filename):
    """Return the model hash for *filename*.

    With no metadata the hash comes from ``hashes.calculate_sha256``.  With
    metadata, the value stored under "sshs_model_hash" is used when present;
    otherwise ``safetensors_hack.hash_file`` computes it.
    """
    if metadata is None:
        return hashes.calculate_sha256(filename)

    if "sshs_model_hash" not in metadata:
        return safetensors_hack.hash_file(filename)

    return metadata["sshs_model_hash"]
| |
|
| |
|
def get_legacy_hash(metadata, filename):
    """Return the legacy hash for *filename*.

    With no metadata the hash comes from ``sd_models.model_hash``.  With
    metadata, the value stored under "sshs_legacy_hash" is used when present;
    otherwise ``safetensors_hack.legacy_hash_file`` computes it.
    """
    if metadata is None:
        return sd_models.model_hash(filename)

    has_embedded_hash = "sshs_legacy_hash" in metadata
    if has_embedded_hash:
        return metadata["sshs_legacy_hash"]
    return safetensors_hack.legacy_hash_file(filename)
| |
|
| |
|
| | import filelock |
# JSON file that persists the hash cache, located in the directory returned
# by scripts.basedir().
cache_filename = os.path.join(scripts.basedir(), "hashes.json")
# In-memory copy of the cache file; stays None until cache() loads it.
cache_data = None
| |
|
| |
|
def cache(subsection):
    """Return the mutable dict stored under *subsection* of the hash cache.

    The cache file is read once, on first use, while holding the
    ``<cache_filename>.lock`` file lock.  A missing file yields an empty
    cache.  An unreadable file, or one whose top-level JSON value is not an
    object, is reported and also treated as empty, since the contents are
    only cached hashes that can be recomputed.

    Args:
        subsection: Top-level key of the cache, e.g. "hashes".

    Returns:
        The dict for that subsection, created if absent.  Changes made to
        it are kept in memory and written out by ``dump_cache``.
    """
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename+".lock"):
            # Re-check under the lock: this function is called from worker
            # threads, and another one may have finished loading meanwhile.
            if cache_data is None:
                loaded = {}
                if os.path.isfile(cache_filename):
                    try:
                        with open(cache_filename, "r", encoding="utf8") as file:
                            loaded = json.load(file)
                    except ValueError as ex:
                        # Covers malformed JSON and undecodable bytes.
                        print(f"[AddNet] Ignoring unreadable hash cache {cache_filename}: {ex}")
                        loaded = {}
                    if not isinstance(loaded, dict):
                        loaded = {}
                cache_data = loaded

    return cache_data.setdefault(subsection, {})
| |
|
| |
|
def dump_cache():
    """Write the in-memory hash cache to ``cache_filename``.

    Does nothing when the cache has not been loaded yet, so the file is
    never replaced by a JSON ``null``.  The data is written to a temporary
    sibling file and then moved over the real one, so an interrupted write
    cannot leave a truncated cache file behind.
    """
    if cache_data is None:
        return

    with filelock.FileLock(cache_filename+".lock"):
        tmp_filename = cache_filename + ".tmp"
        with open(tmp_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)
        os.replace(tmp_filename, cache_filename)
| |
|
| |
|
def get_model_rating(filename):
    """Return the user rating stored in a model's metadata as an int.

    The rating is read from the "ssmd_rating" metadata key.  Files that are
    not ``.safetensors``, files without that key, and files whose stored
    value cannot be converted to an integer all rate 0.
    """
    if not model_util.is_safetensors(filename):
        return 0

    metadata = safetensors_hack.read_metadata(filename)
    try:
        return int(metadata.get("ssmd_rating", "0"))
    except (TypeError, ValueError):
        # The value comes from the model file itself; a malformed rating
        # must not break sorting of the whole model list.
        return 0
| |
|
| |
|
def has_user_metadata(filename):
    """Report whether a model file has any metadata key starting with "ssmd_".

    Files that are not ``.safetensors`` always yield False.
    """
    if model_util.is_safetensors(filename):
        metadata = safetensors_hack.read_metadata(filename)
        for key in metadata.keys():
            if key.startswith("ssmd_"):
                return True
    return False
| |
|
| |
|
def hash_model_file(finfo):
    """Compute, or fetch from the cache, the hashes of one model file.

    Args:
        finfo: ``(path, os.stat_result)`` tuple describing the file.

    Returns:
        ``{"model": ..., "legacy": ..., "fileinfo": finfo}`` on success, or
        ``{"error": exception, "filename": path}`` when the metadata of a
        ``.safetensors`` file cannot be read.  For a file whose stem is
        "None" both hashes are None and the file is not hashed.
    """
    filename, file_stat = finfo[0], finfo[1]
    name = os.path.splitext(os.path.basename(filename))[0]

    # "None" is the reserved "no model" entry and get_all_models never lists
    # a file with that stem, so skip the hashing work.  Returning explicitly
    # also avoids referencing hash variables that were never assigned.
    if name == "None":
        return {"model": None, "legacy": None, "fileinfo": finfo}

    # Reuse the cached hashes as long as the modification time is unchanged.
    cached = cache("hashes").get(filename, None)
    if cached is not None and file_stat.st_mtime == cached["mtime"]:
        return {"model": cached["model"], "legacy": cached["legacy"], "fileinfo": finfo}

    metadata = None
    if model_util.is_safetensors(filename):
        try:
            metadata = safetensors_hack.read_metadata(filename)
        except Exception as ex:
            # Reported to the caller so one bad file does not stop the scan.
            return {"error": ex, "filename": filename}

    model_hash = get_model_hash(metadata, filename)
    legacy_hash = get_legacy_hash(metadata, filename)
    return {"model": model_hash, "legacy": legacy_hash, "fileinfo": finfo}
| |
|
| |
|
def get_all_models(paths, sort_by, filter_by):
    """Scan directories for model files and build the name/hash lookup tables.

    Args:
        paths: Directories to scan recursively; non-directories are ignored.
        sort_by: One of "name", "date", "path name", "rating" or
            "has user metadata"; any other value leaves the order in which
            hashing finished.
        filter_by: Case-insensitive substring that the file name must
            contain; blank means no filtering.

    Returns:
        A ``(res, res_legacy)`` tuple of OrderedDicts.  ``res`` maps the
        display name ``"<stem>(<first 12 chars of model hash>)"`` to the file
        path, and ``res_legacy`` maps the legacy hash to that display name.

    Side effects:
        Prints progress, and records each listed file's hashes and mtime in
        the in-memory "hashes" cache (persisting it is left to the caller).
    """
    fileinfos = []
    for path in paths:
        if os.path.isdir(path):
            fileinfos += traverse_all_files(path, [])

    show_only_safetensors = shared.opts.data.get("additional_networks_show_only_safetensors", False)
    show_only_missing_meta = shared.opts.data.get("additional_networks_show_only_models_with_metadata", "disabled")

    if show_only_safetensors:
        fileinfos = [x for x in fileinfos if is_safetensors(x[0])]

    if show_only_missing_meta == "has metadata":
        fileinfos = [x for x in fileinfos if has_user_metadata(x[0])]
    elif show_only_missing_meta == "missing metadata":
        fileinfos = [x for x in fileinfos if not has_user_metadata(x[0])]

    print("[AddNet] Updating model hashes...")
    data = []
    thread_count = max(1, int(shared.opts.data.get("additional_networks_hash_thread_count", 1)))
    # The context manager releases the worker threads even when hashing
    # raises part-way through the scan.
    with Pool(processes=thread_count) as pool, tqdm.tqdm(total=len(fileinfos)) as pbar:
        for hashed in pool.imap_unordered(hash_model_file, fileinfos):
            pbar.update()
            if "error" in hashed:
                print(f"Failed to read model file {hashed['filename']}: {hashed['error']}")
            else:
                data.append(hashed)

    cache_hashes = cache("hashes")

    res = OrderedDict()
    res_legacy = OrderedDict()
    filter_by = filter_by.strip(" ")
    if len(filter_by) != 0:
        data = [x for x in data if filter_by.lower() in os.path.basename(x["fileinfo"][0]).lower()]
    if sort_by == "name":
        data = sorted(data, key=lambda x: os.path.basename(x["fileinfo"][0]))
    elif sort_by == "date":
        # Newest first.
        data = sorted(data, key=lambda x: -x["fileinfo"][1].st_mtime)
    elif sort_by == "path name":
        data = sorted(data, key=lambda x: x["fileinfo"][0])
    elif sort_by == "rating":
        data = sorted(data, key=lambda x: get_model_rating(x["fileinfo"][0]), reverse=True)
    elif sort_by == "has user metadata":
        # Files without user metadata get an empty key and therefore sort last.
        data = sorted(data, key=lambda x: os.path.basename(x["fileinfo"][0]) if has_user_metadata(x["fileinfo"][0]) else "", reverse=True)

    reverse = shared.opts.data.get("additional_networks_reverse_sort_order", False)
    if reverse:
        data = reversed(data)

    for result in data:
        finfo = result["fileinfo"]
        filename = finfo[0]
        file_stat = finfo[1]
        model_hash = result["model"]
        legacy_hash = result["legacy"]

        name = os.path.splitext(os.path.basename(filename))[0]

        # Commas are replaced in the display name.
        name = name.replace(",", "_")

        # "None" is reserved for the "no model" entry, so a file with that
        # stem is never listed.
        if name != "None":
            full_name = name + f"({model_hash[0:12]})"
            res[full_name] = filename
            res_legacy[legacy_hash] = full_name
            cache_hashes[filename] = {"model": model_hash, "legacy": legacy_hash, "mtime": file_stat.st_mtime}

    return res, res_legacy
| |
|
| |
|
def find_closest_lora_model_name(search: str):
    """Resolve a user-supplied model reference to a registered display name.

    Lookups are tried in this order, returning the first hit:

    1. *search* is already a key of ``lora_models``.
    2. *search* is an existing file whose normalized path equals a
       registered model path.
    3. The lower-cased *search* is a key of ``lora_model_names``.
    4. *search* ends in ``(xxxxxxxx)`` and that 8-hex-digit legacy hash is
       known in ``legacy_model_names``.
    5. The shortest key of ``lora_model_names`` containing the lower-cased
       *search* as a substring.

    Returns:
        The matching display name, or None when *search* is empty, equals
        "None", or nothing matches.
    """
    if not search or search == "None":
        return None

    # Exact display name.
    if search in lora_models:
        return search

    # Path to a model file.
    if os.path.isfile(search):
        find = os.path.normpath(search)
        value = next((k for k, path in lora_models.items() if path == find), None)
        if value:
            return value

    search = search.lower()

    # Exact file stem.
    if search in lora_model_names:
        return lora_model_names.get(search)

    # Trailing legacy hash.
    result = re_legacy_hash.search(search)
    if result is not None:
        model_hash = result.group(1)
        if model_hash in legacy_model_names:
            return legacy_model_names[model_hash]

    # Substring match; prefer the shortest candidate (first one on ties).
    applicable = [name for name in lora_model_names if search in name.lower()]
    if not applicable:
        return None
    return lora_model_names[min(applicable, key=len)]
| |
|
| |
|
def update_models():
    """Rescan the model directories and rebuild the module-level registries.

    Scans ``lora_models_dir`` plus every existing directory listed in the
    "additional_networks_extra_lora_path" option, using the configured sort
    mode and name filter.  Then refreshes ``lora_models``,
    ``lora_model_names`` and ``legacy_model_names`` and writes the hash
    cache to disk.
    """
    global lora_models, lora_model_names, legacy_model_names
    paths = [lora_models_dir]
    extra_lora_paths = util.split_path_list(shared.opts.data.get("additional_networks_extra_lora_path", ""))
    for path in extra_lora_paths:
        path = path.lstrip()
        if os.path.isdir(path):
            paths.append(path)

    sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
    filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")
    res, res_legacy = get_all_models(paths, sort_by, filter_by)

    lora_models.clear()
    lora_models["None"] = None
    lora_models.update(res)

    # Rebuild the stem lookup from scratch so models that were deleted or
    # filtered out since the previous scan no longer resolve by name.
    lora_model_names.clear()
    for name_and_hash, filename in lora_models.items():
        if filename is None:
            # The "None" placeholder has no file behind it.
            continue
        name = os.path.splitext(os.path.basename(filename))[0].lower()
        lora_model_names[name] = name_and_hash

    legacy_model_names = res_legacy
    dump_cache()
| |
|
| |
|
# Populate the model registries once, when this module is imported.
update_models()
| |
|