File size: 4,532 Bytes
7bed60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from __future__ import annotations

import importlib
import sys
from functools import lru_cache
from pathlib import Path

from modules import extensions, sd_models, shared
from modules.paths import extensions_builtin_dir, extensions_dir, models_path

from .common import cn_model_module, cn_model_regex

# WebUI extension directories as Path objects.
ext_path = Path(extensions_dir)
ext_builtin_path = Path(extensions_builtin_dir)
# Module-level state describing the ControlNet extension. These defaults stay
# in effect when no matching extension is found by the loop below.
controlnet_exists = False
controlnet_type = "standard"
controlnet_path = None
# Dotted import prefix for the ControlNet extension ("" until it is located).
cn_base_path = ""

# Pick the first enabled extension whose name contains "sd-webui-controlnet".
for extension in extensions.active():
    if not extension.enabled:
        continue
    # For cases like sd-webui-controlnet-master
    if "sd-webui-controlnet" in extension.name:
        controlnet_exists = True
        controlnet_path = Path(extension.path)
        # Join the last two path components with ".", giving
        # "<parent dir name>.<extension dir name>".
        cn_base_path = ".".join(controlnet_path.parts[-2:])
        break

# Make the dotted prefix above importable: if the extension sits directly
# inside a directory named "extensions" or "extensions-builtin", append that
# directory's parent to sys.path (once).
if controlnet_path is not None:
    # NOTE: despite its name, this is the *parent* directory of the extension.
    sd_webui_controlnet_path = controlnet_path.resolve().parent
    if sd_webui_controlnet_path.stem in ("extensions", "extensions-builtin"):
        target_path = str(sd_webui_controlnet_path.parent)
        if target_path not in sys.path:
            sys.path.append(target_path)


class ControlNetExt:
    """Adapter around the ControlNet extension's ``external_code`` module.

    A fresh instance reports ControlNet as unavailable; `init_controlnet`
    must succeed before `update_scripts_args` has any effect.
    """

    def __init__(self):
        # "None" is always offered as the first model choice.
        self.cn_models = ["None"]
        self.cn_available = False
        self.external_cn = None

    def init_controlnet(self):
        """Import ``external_code`` and collect the usable model names.

        The module is imported from ``<cn_base_path>.scripts.external_code``.
        Every name returned by its ``get_models()`` that matches
        `cn_model_regex` is appended to ``self.cn_models``.
        """
        module_name = f"{cn_base_path}.scripts.external_code"
        self.external_cn = importlib.import_module(module_name, "external_code")
        self.cn_available = True

        for candidate in self.external_cn.get_models():
            if cn_model_regex.search(candidate):
                self.cn_models.append(candidate)

    def update_scripts_args(
        self,
        p,
        model: str,
        module: str | None,
        weight: float,
        guidance_start: float,
        guidance_end: float,
    ):
        """Attach a single ControlNet unit to the processing object ``p``.

        Does nothing when ControlNet has not been initialised or when
        ``model`` is the string ``"None"``.

        Args:
            p: Processing object handed to ControlNet's
                ``update_cn_script_in_processing``.
            model: ControlNet model name.
            module: Preprocessor name. ``None`` or ``"None"`` means "look it
                up": the first key of `cn_model_module` contained in
                ``model`` decides, and the module stays ``None`` if no key
                matches.
            weight: Weight passed to the ControlNet unit.
            guidance_start: Guidance start passed to the ControlNet unit.
            guidance_end: Guidance end passed to the ControlNet unit.

        Raises:
            RuntimeError: If ControlNet raises an ``AttributeError`` that
                mentions ``script_args_value``.
            AttributeError: Any other ``AttributeError`` from ControlNet is
                re-raised unchanged.
        """
        if not self.cn_available:
            return
        if model == "None":
            return

        if module == "None" or module is None:
            # First mapping key found inside the model name wins.
            module = next(
                (mapped for key, mapped in cn_model_module.items() if key in model),
                None,
            )

        unit = self.external_cn.ControlNetUnit(
            model=model,
            weight=weight,
            control_mode=self.external_cn.ControlMode.BALANCED,
            module=module,
            guidance_start=guidance_start,
            guidance_end=guidance_end,
            pixel_perfect=True,
        )

        try:
            self.external_cn.update_cn_script_in_processing(p, [unit])
        except AttributeError as err:
            if "script_args_value" in str(err):
                msg = "[-] Adetailer: ControlNet option not available in WEBUI version lower than 1.6.0 due to updates in ControlNet"
                raise RuntimeError(msg) from err
            raise


def get_cn_model_dirs() -> list[Path]:
    """Return the directories to search for ControlNet model files.

    The list always starts with ``<models_path>/ControlNet``. It is followed,
    in this order and only when set to a non-empty value, by:

    * the ``models`` directory inside the ControlNet extension (when the
      extension was located),
    * the ``control_net_models_path`` option,
    * the ``controlnet_dir`` command-line option.

    The directories are not checked for existence here.
    """
    result = [Path(models_path, "ControlNet")]

    legacy_dir = None if controlnet_path is None else controlnet_path.joinpath("models")
    candidates = [
        legacy_dir,
        shared.opts.data.get("control_net_models_path", ""),
        getattr(shared.cmd_opts, "controlnet_dir", ""),
    ]

    for candidate in candidates:
        if candidate:
            result.append(Path(candidate))
    return result


@lru_cache
def _get_cn_models() -> list[str]:
    """
    List ControlNet model names by scanning the model directories.

    Since we can't import ControlNet, we use a function that does something like
    controlnet's `list(global_state.cn_models_names.values())`.

    Every existing directory from `get_cn_model_dirs()` is searched
    recursively for files whose suffix is a known checkpoint extension and
    whose file name matches `cn_model_regex`. When the
    `control_net_models_name_filter` option is non-empty, only file names
    containing it (compared in lower case) are kept.

    Returns:
        Names formatted as ``"<file stem> [<model hash>]"``, ordered by file
        name, with repeated entries removed. The result is memoized by
        `lru_cache`, so the scan runs only on the first call.
    """
    cn_model_exts = (".pt", ".pth", ".ckpt", ".safetensors")
    name_filter = shared.opts.data.get("control_net_models_name_filter", "")
    name_filter = name_filter.strip(" ").lower()

    model_paths = []
    for base in get_cn_model_dirs():
        if not base.exists():
            continue

        for p in base.rglob("*"):
            if not (
                p.is_file()
                and p.suffix in cn_model_exts
                and cn_model_regex.search(p.name)
            ):
                continue
            if name_filter and name_filter not in p.name.lower():
                continue
            model_paths.append(p)
    model_paths.sort(key=lambda p: p.name)

    names = [f"{p.stem} [{sd_models.model_hash(p)}]" for p in model_paths]

    # The search directories can overlap (e.g. a configured directory that is
    # the same as, or nested inside, another one), in which case the same
    # file is found more than once. dict.fromkeys drops the repeated entries
    # while keeping the first-seen (sorted) order.
    return list(dict.fromkeys(names))


def get_cn_models() -> list[str]:
    """Return the ControlNet model names, or ``[]`` if ControlNet was not found.

    The names come from the cached `_get_cn_models()` scan and are only
    computed when `controlnet_exists` is true.
    """
    return _get_cn_models() if controlnet_exists else []