|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

from modules import scripts, script_callbacks
|
|
|
|
|
# Settings files live next to this script: "<script>.yaml" stores the last-used
# UI values and "<script>.presets.yaml" stores the user's named presets.
CONFIG_PATH = Path(__file__).with_suffix(".yaml")
PRESETS_PATH = CONFIG_PATH.with_name(f"{CONFIG_PATH.stem}.presets.yaml")
|
|
|
|
|
|
|
|
|
|
|
# Selectable output sizes grouped by orientation; every entry is (width, height).
RESOLUTION_GROUPS = {
    "Квадрат": [(1024, 1024)],
    "Портрет": [(640, 1536), (768, 1344), (832, 1216), (896, 1152)],
    "Альбом": [(1536, 640), (1344, 768), (1216, 832), (1152, 896)],
}

# Dropdown labels: a leading "do not apply" entry followed by
# "<group>: <width>x<height>" for every size, in declaration order.
RESOLUTION_CHOICES: List[str] = ["— не применять —"] + [
    f"{group}: {w}x{h}"
    for group, dims in RESOLUTION_GROUPS.items()
    for w, h in dims
]
|
|
|
|
|
|
|
|
def parse_resolution_label(label: str) -> Optional[Tuple[int, int]]:
    """Extract ``(width, height)`` from a resolution dropdown label.

    Labels look like ``"<group>: <W>x<H>"``. Only the text after the last
    colon is parsed, so a bare ``"<W>x<H>"`` is accepted as well.

    Args:
        label: The selected dropdown label.

    Returns:
        ``(width, height)`` as positive integers, or ``None`` when the label
        is empty, not a string, starts with the "—" placeholder marker, or
        does not contain a valid positive size.
    """
    if not isinstance(label, str) or not label or label.startswith("—"):
        return None

    # Everything before the last colon is the human-readable group name.
    _, _, size_part = label.rpartition(":")
    width_text, separator, height_text = size_part.strip().lower().partition("x")
    if not separator:
        return None

    try:
        width, height = int(width_text), int(height_text)
    except ValueError:
        return None

    # A zero or negative size must never be handed to the image pipeline.
    if width <= 0 or height <= 0:
        return None
    return width, height
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_mode(mode: str) -> str: |
|
|
if mode == "nearest-exact": |
|
|
return mode |
|
|
if mode in {"bicubic", "bilinear", "nearest"}: |
|
|
return mode |
|
|
return "bilinear" |
|
|
|
|
|
|
|
|
def _load_yaml(path: Path, default: dict) -> dict:
    """Load a YAML mapping from *path*, falling back to *default*.

    Args:
        path: File to read.
        default: Value returned when the file is missing, unreadable, empty,
            or its top-level node is not a mapping.

    Returns:
        The parsed mapping as a plain ``dict`` (interpolations resolved), or
        *default*.
    """
    try:
        loaded = OmegaConf.to_container(OmegaConf.load(path), resolve=True)
    except FileNotFoundError:
        # A missing file is the normal first-run state; stay quiet.
        return default
    except Exception:
        # Best-effort: a corrupt settings file must not break the UI, but the
        # user should be able to find out why their settings were ignored.
        logging.getLogger(__name__).warning(
            "Could not read %s; using defaults", path, exc_info=True
        )
        return default

    # Callers use the result as a mapping, so reject list-rooted documents.
    if isinstance(loaded, dict) and loaded:
        return loaded
    return default
|
|
|
|
|
|
|
|
def _atomic_save_yaml(path: Path, data: dict) -> None:
    """Write *data* as YAML to *path* via a temporary sibling file.

    The data is first saved to ``<path>.tmp`` and then moved over *path* with
    ``Path.replace``, so *path* is only replaced once the temporary file has
    been written. Failures never propagate: they are logged and the leftover
    temporary file is removed.

    Args:
        path: Destination file.
        data: Mapping to serialize.
    """
    tmp: Optional[Path] = None
    try:
        tmp = path.with_suffix(path.suffix + ".tmp")
        OmegaConf.save(DictConfig(data), tmp)
        tmp.replace(path)
    except Exception:
        # Saving is best-effort (it must not abort a generation), but a silent
        # failure would leave the user believing the settings were stored.
        logging.getLogger(__name__).warning(
            "Could not save %s", path, exc_info=True
        )
        if tmp is not None:
            try:
                tmp.unlink()
            except OSError:
                pass  # nothing was written, or the file cannot be removed
|
|
|
|
|
|
|
|
def _load_presets() -> Dict[str, dict]:
    """Read all named presets from ``PRESETS_PATH``.

    Returns:
        A mapping of preset name (coerced to ``str``) to a copy of its
        settings dict. Entries whose value is not a mapping (for example a
        hand-edited ``name: null``) are skipped instead of raising, and a file
        whose root is not a mapping yields an empty result.
    """
    data = _load_yaml(PRESETS_PATH, {})
    if not isinstance(data, dict):
        return {}
    return {
        str(name): dict(settings)
        for name, settings in data.items()
        if isinstance(settings, dict)
    }
|
|
|
|
|
|
|
|
def _save_presets(presets: Dict[str, dict]) -> None:
    """Persist the complete preset mapping to ``PRESETS_PATH``."""
    _atomic_save_yaml(path=PRESETS_PATH, data=presets)
|
|
|
|
|
|
|
|
def _clamp(x: float, lo: float, hi: float) -> float: |
|
|
return float(max(lo, min(hi, x))) |
|
|
|
|
|
|
|
|
def _norm_mode_choice(value: str, default_: str = "false") -> str: |
|
|
"""Привести выбор из UI к {'true','false','auto'}.""" |
|
|
s = str(value or "").strip().lower() |
|
|
if s in ("true",): |
|
|
return "true" |
|
|
if s in ("false",): |
|
|
return "false" |
|
|
if s in ("авто", "auto"): |
|
|
return "auto" |
|
|
return default_ |
|
|
|
|
|
|
|
|
def _compute_adaptive_params(
    width: int,
    height: int,
    profile: str,
    base_s1: float,
    base_s2: float,
    base_d1: int,
    base_d2: int,
    base_down: float,
    base_up: float,
    keep_unitary_product: bool,
) -> Tuple[float, float, int, int, float, float]:
    """Adapt (s1, s2, d1, d2, downscale, upscale) to pixel count and aspect ratio.

    Args:
        width, height: Target image size in pixels.
        profile: Adaptation profile label; matched by the substrings
            "консер" (conservative) and "агресс" (aggressive), anything else
            leaves the adjustments unscaled.
        base_s1, base_s2: User-selected stop fractions.
        base_d1, base_d2: User-selected block depths.
        base_down: User-selected downscale factor.
        base_up: User-selected upscale factor.
        keep_unitary_product: When true the returned upscale is the reciprocal
            of the adapted downscale; otherwise *base_up* is returned as is.

    Returns:
        ``(s1, s2, d1, d2, downscale, upscale)`` with the stop fractions
        clamped to [0, 0.5], depths to [1, 10] and downscale to [0.3, 0.9].
    """
    # Size relative to a 1024x1024 image, and long-side / short-side ratio.
    megapixels = (max(1, int(width)) * max(1, int(height))) / float(1024 * 1024)
    long_side, short_side = max(width, height), min(width, height)
    elongation = long_side / float(max(1, short_side))

    stop_shift = 0.0
    depth_shift = 0
    shrink = float(base_down)

    # Pixel-count tiers: larger images stop later and shrink harder.
    if megapixels >= 1.5:
        stop_shift, shrink = 0.08, shrink - 0.10
    elif megapixels >= 1.1:
        stop_shift, shrink = 0.05, shrink - 0.05
    elif megapixels <= 0.8:
        stop_shift, shrink = -0.02, shrink + 0.05

    # Elongated images go one block deeper; very elongated ones go two deeper.
    if elongation >= 1.6:
        depth_shift += 1
        shrink -= 0.05
        if elongation >= 2.0:
            depth_shift += 1
            stop_shift += 0.02

    profile_key = (profile or "Сбалансированный").strip().lower()
    if "консер" in profile_key:
        stop_shift *= 0.6
        # NOTE(review): this halves the distance to a fixed 0.5 rather than to
        # base_down — confirm that anchoring at 0.5 is intended.
        shrink = 0.5 + 0.5 * (shrink - 0.5)
    elif "агресс" in profile_key:
        stop_shift *= 1.3
        shrink -= 0.05

    def _depth(base: int) -> int:
        return max(1, min(10, int(base + depth_shift)))

    s1 = _clamp(base_s1 + stop_shift, 0.0, 0.5)
    s2 = _clamp(base_s2 + stop_shift, 0.0, 0.5)
    d1 = _depth(base_d1)
    d2 = _depth(base_d2)
    shrink = _clamp(shrink, 0.3, 0.9)

    grow = 1.0 / max(1e-6, shrink) if keep_unitary_product else float(base_up)
    return s1, s2, d1, d2, shrink, grow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Scaler(torch.nn.Module):
    """U-Net block wrapper: rescale the input, then call the wrapped module.

    Attributes are plain (non-parameter) settings:
        scale: ``scale_factor`` passed to ``F.interpolate``; may be reassigned
            between calls.
        block: The wrapped module, called with the rescaled tensor.
        scaler: Interpolation mode, sanitized by ``_safe_mode``.
        align_mode / recompute_mode: 'true', 'false' or 'auto'. 'auto' means
            the corresponding ``F.interpolate`` argument is not passed at all.
    """

    def __init__(
        self,
        scale: float,
        block: torch.nn.Module,
        scaler: str,
        align_mode: str = "false",
        recompute_mode: str = "false",
    ) -> None:
        super().__init__()
        self.scale: float = float(scale)
        self.block: torch.nn.Module = block
        self.scaler: str = _safe_mode(scaler)
        self.align_mode: str = _norm_mode_choice(align_mode, "false")
        self.recompute_mode: str = _norm_mode_choice(recompute_mode, "false")

    def _interpolate_kwargs(self, mode: str) -> Dict[str, Any]:
        """Build the ``F.interpolate`` keyword arguments for *mode*."""
        kwargs: Dict[str, Any] = {"scale_factor": self.scale, "mode": mode}
        # align_corners is only passed for the bilinear/bicubic modes.
        if mode in ("bilinear", "bicubic") and self.align_mode in ("true", "false"):
            kwargs["align_corners"] = self.align_mode == "true"
        if self.recompute_mode in ("true", "false"):
            kwargs["recompute_scale_factor"] = self.recompute_mode == "true"
        return kwargs

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Rescale *x* by ``self.scale`` and forward it to the wrapped block.

        If interpolation with the configured mode fails, it is retried once
        with a fallback mode ("nearest" for "nearest-exact", otherwise
        "bilinear"). Extra positional/keyword arguments are passed through to
        the wrapped block unchanged.
        """
        mode = self.scaler
        try:
            x = F.interpolate(x, **self._interpolate_kwargs(mode))
        except Exception:
            fallback = "nearest" if mode == "nearest-exact" else "bilinear"
            if fallback == mode:
                # Retrying the identical call cannot succeed; surface the error.
                raise
            x = F.interpolate(x, **self._interpolate_kwargs(fallback))
        return self.block(x, *args, **kwargs)
|
|
|
|
|
|
|
|
class KohyaHiresFix(scripts.Script):
    """Dynamic hires.fix via a temporary rescale of internal U-Net features.

    While enabled, a CFG-denoiser callback wraps selected ``input_blocks`` and
    ``output_blocks`` of the diffusion model in :class:`Scaler` during the
    early sampling steps and restores the original blocks afterwards.
    """

    _log = logging.getLogger(__name__)

    # (preset key, caster, default) for every setting stored in a named preset.
    # The order defines both the saved layout and the Gradio wiring order of
    # the preset save/load callbacks.
    _PRESET_SCHEMA = (
        ("d1", int, 3),
        ("d2", int, 4),
        ("s1", float, 0.15),
        ("s2", float, 0.30),
        ("scaler", str, "bicubic"),
        ("downscale", float, 0.5),
        ("upscale", float, 2.0),
        ("smooth_scaling", bool, True),
        ("early_out", bool, False),
        ("only_one_pass", bool, True),
        ("keep_unitary_product", bool, False),
        ("align_corners_mode", str, "False"),
        ("recompute_scale_factor_mode", str, "False"),
        ("resolution_choice", str, RESOLUTION_CHOICES[0]),
        ("apply_resolution", bool, False),
        ("adaptive_by_resolution", bool, True),
        ("adaptive_profile", str, "Сбалансированный"),
    )

    def __init__(self) -> None:
        super().__init__()
        # Last-used UI values, restored from the config file next to the script.
        self.config: DictConfig = DictConfig(_load_yaml(CONFIG_PATH, {}))
        self.disable: bool = False
        # Denoiser steps below this index are skipped ("only one pass" mode).
        self.step_limit: int = 0
        self.infotext_fields = []
        self._cb_registered: bool = False

        self.p1: Tuple[float, int] = (0.15, 2)
        self.p2: Tuple[float, int] = (0.30, 3)

    def title(self) -> str:
        """Return the script title shown in the web UI."""
        return "Kohya Hires.fix · Русская версия"

    def show(self, is_img2img: bool):
        """Make the script always visible in both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the Gradio controls.

        Returns:
            The components whose values are passed to :meth:`process`, in the
            order of its parameters after ``p``.
        """
        self.infotext_fields = []
        presets = _load_presets()

        with gr.Accordion(label="Kohya Hires.fix", open=False):
            enable = gr.Checkbox(label="Включить расширение", value=False)

            # Preset resolutions
            with gr.Group():
                gr.Markdown("**Предустановленные разрешения**")
                with gr.Row():
                    resolution_choice = gr.Dropdown(
                        choices=RESOLUTION_CHOICES,
                        value=self.config.get("resolution_choice", RESOLUTION_CHOICES[0]),
                        label="Выбрать разрешение",
                    )
                    apply_resolution = gr.Checkbox(
                        label="Применять выбранное разрешение к ширине/высоте",
                        value=self.config.get("apply_resolution", False),
                    )

            # Scaling parameters
            with gr.Group():
                gr.Markdown("**Параметры масштабирования**")
                with gr.Row():
                    s1 = gr.Slider(0.0, 0.5, step=0.01, label="Остановить на (доля шага) — Пара 1",
                                   value=self.config.get("s1", 0.15))
                    d1 = gr.Slider(1, 10, step=1, label="Глубина блока — Пара 1",
                                   value=self.config.get("d1", 3))
                with gr.Row():
                    s2 = gr.Slider(0.0, 0.5, step=0.01, label="Остановить на (доля шага) — Пара 2",
                                   value=self.config.get("s2", 0.30))
                    d2 = gr.Slider(1, 10, step=1, label="Глубина блока — Пара 2",
                                   value=self.config.get("d2", 4))

                with gr.Row():
                    scaler = gr.Dropdown(
                        choices=["bicubic", "bilinear", "nearest", "nearest-exact"],
                        label="Режим интерполяции слоя",
                        value=self.config.get("scaler", "bicubic"),
                    )
                    downscale = gr.Slider(0.1, 1.0, step=0.05, label="Коэффициент даунскейла (вход)",
                                          value=self.config.get("downscale", 0.5))
                    upscale = gr.Slider(1.0, 4.0, step=0.1, label="Коэффициент апскейла (выход)",
                                        value=self.config.get("upscale", 2.0))

                with gr.Row():
                    smooth_scaling = gr.Checkbox(label="Плавное изменение масштаба",
                                                 value=self.config.get("smooth_scaling", True))
                    keep_unitary_product = gr.Checkbox(
                        label="Сохранять суммарный масштаб = 1 при сглаживании",
                        value=self.config.get("keep_unitary_product", False),
                    )
                    early_out = gr.Checkbox(label="Ранний апскейл на прямом индексе выхода",
                                            value=self.config.get("early_out", False))
                    only_one_pass = gr.Checkbox(label="Только один проход (отключить на следующих шагах)",
                                                value=self.config.get("only_one_pass", True))

            # Advanced interpolation options
            with gr.Group():
                gr.Markdown("**Интерполяция (продвинутое)**")
                with gr.Row():
                    align_corners_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=self.config.get("align_corners_mode", "False"),
                        label="align_corners режим",
                    )
                    recompute_scale_factor_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=self.config.get("recompute_scale_factor_mode", "False"),
                        label="recompute_scale_factor режим",
                    )

            # Resolution-based adaptation
            with gr.Group():
                gr.Markdown("**Адаптация под разрешение**")
                with gr.Row():
                    adaptive_by_resolution = gr.Checkbox(
                        label="Адаптировать параметры под текущее разрешение",
                        value=self.config.get("adaptive_by_resolution", True),
                    )
                    adaptive_profile = gr.Dropdown(
                        choices=["Консервативный", "Сбалансированный", "Агрессивный"],
                        value=self.config.get("adaptive_profile", "Сбалансированный"),
                        label="Профиль адаптации",
                    )

            # Named presets
            with gr.Group():
                gr.Markdown("**Именуемые пресеты**")
                with gr.Row():
                    preset_select = gr.Dropdown(
                        choices=sorted(presets),
                        value=None,
                        label="Выбрать пресет",
                    )
                    preset_name = gr.Textbox(
                        label="Имя пресета для сохранения/переопределения",
                        placeholder="например: xl-portrait-hires",
                        value="",
                    )
                with gr.Row():
                    btn_save = gr.Button("Сохранить как пресет", variant="primary")
                    btn_load = gr.Button("Загрузить пресет")
                    btn_delete = gr.Button("Удалить пресет", variant="stop")
                preset_status = gr.Markdown("")

            # Components stored in a preset, ordered exactly like _PRESET_SCHEMA
            # so the save inputs and load outputs can never drift apart.
            by_key = {
                "d1": d1, "d2": d2, "s1": s1, "s2": s2,
                "scaler": scaler, "downscale": downscale, "upscale": upscale,
                "smooth_scaling": smooth_scaling, "early_out": early_out,
                "only_one_pass": only_one_pass, "keep_unitary_product": keep_unitary_product,
                "align_corners_mode": align_corners_mode,
                "recompute_scale_factor_mode": recompute_scale_factor_mode,
                "resolution_choice": resolution_choice, "apply_resolution": apply_resolution,
                "adaptive_by_resolution": adaptive_by_resolution,
                "adaptive_profile": adaptive_profile,
            }
            preset_fields = [by_key[key] for key, _caster, _default in self._PRESET_SCHEMA]

            def _save_preset_cb(name: str, *values: Any):
                """Store the current control values under *name*."""
                name = (name or "").strip()
                if not name:
                    return gr.update(), "⚠️ Укажите имя пресета."
                current = _load_presets()
                current[name] = {
                    key: caster(value)
                    for (key, caster, _default), value in zip(self._PRESET_SCHEMA, values)
                }
                _save_presets(current)
                return gr.update(choices=sorted(current), value=name), f"✅ Сохранено пресет «{name}»."

            btn_save.click(
                _save_preset_cb,
                inputs=[preset_name, *preset_fields],
                outputs=[preset_select, preset_status],
            )

            def _load_preset_cb(selected: Optional[str]):
                """Push the values of the selected preset into the controls."""
                name = (selected or "").strip()
                stored = _load_presets()
                if not name or name not in stored:
                    # Leave every control untouched and only report the problem.
                    return (
                        *(gr.update() for _ in self._PRESET_SCHEMA),
                        gr.update(value=name),
                        "⚠️ Пресет не выбран или не найден.",
                    )
                preset = stored[name]
                return (
                    *(
                        self._coerce(caster, preset.get(key, default), default)
                        for key, caster, default in self._PRESET_SCHEMA
                    ),
                    gr.update(value=name),
                    f"✅ Загружен пресет «{name}».",
                )

            btn_load.click(
                _load_preset_cb,
                inputs=[preset_select],
                outputs=[*preset_fields, preset_name, preset_status],
            )

            def _delete_preset_cb(selected: Optional[str]):
                """Remove the selected preset from the presets file."""
                name = (selected or "").strip()
                current = _load_presets()
                if not name or name not in current:
                    return gr.update(), "⚠️ Пресет не выбран или не найден."
                current.pop(name, None)
                _save_presets(current)
                return gr.update(choices=sorted(current), value=None), f"🗑️ Удалён пресет «{name}»."

            btn_delete.click(
                _delete_preset_cb,
                inputs=[preset_select],
                outputs=[preset_select, preset_status],
            )

        # Infotext wiring: the extension counts as enabled when the pasted
        # parameters contain a truthy "DSHF_s1" entry.
        self.infotext_fields.append((enable, lambda d: d.get("DSHF_s1", False)))
        infotext_map = {
            "DSHF_res": resolution_choice, "DSHF_apply_res": apply_resolution,
            "DSHF_s1": s1, "DSHF_d1": d1, "DSHF_s2": s2, "DSHF_d2": d2,
            "DSHF_scaler": scaler, "DSHF_down": downscale, "DSHF_up": upscale,
            "DSHF_smooth": smooth_scaling, "DSHF_early": early_out,
            "DSHF_one": only_one_pass, "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode, "DSHF_recompute": recompute_scale_factor_mode,
            "DSHF_adapt": adaptive_by_resolution, "DSHF_adapt_profile": adaptive_profile,
        }
        for key, element in infotext_map.items():
            self.infotext_fields.append((element, key))

        return [
            enable,
            only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale,
            smooth_scaling, early_out, keep_unitary_product,
            align_corners_mode, recompute_scale_factor_mode,
            resolution_choice, apply_resolution,
            adaptive_by_resolution, adaptive_profile,
            preset_select, preset_name,
        ]

    @staticmethod
    def _coerce(caster, raw, default):
        """Cast *raw* with *caster*; return *default* if the value is malformed."""
        try:
            return caster(raw)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _unwrap_all(model) -> None:
        """Replace every :class:`Scaler` in the model's block lists by its wrapped block."""
        if not model:
            return
        for i, b in enumerate(getattr(model, "input_blocks", [])):
            if isinstance(b, Scaler):
                model.input_blocks[i] = b.block
        for i, b in enumerate(getattr(model, "output_blocks", [])):
            if isinstance(b, Scaler):
                model.output_blocks[i] = b.block

    def _remove_callbacks(self) -> None:
        """Drop the callbacks registered by this script and reset the flag."""
        try:
            script_callbacks.remove_current_script_callbacks()
        except Exception:
            # Best-effort: failing to deregister must not abort a generation.
            self._log.debug("Could not remove script callbacks", exc_info=True)
        self._cb_registered = False

    def process(
        self,
        p,
        enable: bool,
        only_one_pass: bool,
        d1: int,
        d2: int,
        s1: float,
        s2: float,
        scaler: str,
        downscale: float,
        upscale: float,
        smooth_scaling: bool,
        early_out: bool,
        keep_unitary_product: bool,
        align_corners_mode_ui: str,
        recompute_scale_factor_mode_ui: str,
        resolution_choice: str,
        apply_resolution: bool,
        adaptive_by_resolution: bool,
        adaptive_profile: str,
        selected_preset: Optional[str],
        new_preset_name: str,
    ):
        """Remember the UI values and install the denoiser callback.

        The arguments after ``p`` arrive in the order returned by :meth:`ui`.
        ``selected_preset`` and ``new_preset_name`` are accepted only to match
        that list and are not used here.
        """
        align_mode = _norm_mode_choice(align_corners_mode_ui, "false")
        recompute_mode = _norm_mode_choice(recompute_scale_factor_mode_ui, "false")

        # Remember the raw UI values; postprocess() writes them to the config file.
        self.config = DictConfig({
            "s1": s1, "s2": s2, "d1": d1, "d2": d2,
            "scaler": scaler, "downscale": downscale, "upscale": upscale,
            "smooth_scaling": smooth_scaling, "early_out": early_out, "only_one_pass": only_one_pass,
            "keep_unitary_product": keep_unitary_product,
            "align_corners_mode": align_corners_mode_ui,
            "recompute_scale_factor_mode": recompute_scale_factor_mode_ui,
            "resolution_choice": resolution_choice, "apply_resolution": apply_resolution,
            "adaptive_by_resolution": adaptive_by_resolution, "adaptive_profile": adaptive_profile,
        })
        self.step_limit = 0

        # The preset resolution is applied even when the extension is disabled.
        if apply_resolution:
            wh = parse_resolution_label(resolution_choice)
            if wh:
                p.width, p.height = wh

        # Always drop callbacks left over from an earlier run. The flag alone
        # is not reliable (an interrupted run never reaches postprocess), and
        # a stale callback would keep rescaling with outdated settings.
        self._remove_callbacks()

        if not enable or self.disable:
            try:
                KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
            except Exception:
                # Best-effort cleanup; the model may not expose this layout.
                self._log.debug("Could not unwrap U-Net blocks", exc_info=True)
            return

        use_s1, use_s2 = s1, s2
        use_d1, use_d2 = d1, d2
        use_down, use_up = downscale, upscale

        if adaptive_by_resolution:
            try:
                use_s1, use_s2, use_d1, use_d2, use_down, use_up = _compute_adaptive_params(
                    int(p.width), int(p.height),
                    adaptive_profile,
                    s1, s2, d1, d2,
                    downscale, upscale,
                    keep_unitary_product,
                )
            except Exception:
                # Fall back to the unadapted UI values.
                self._log.warning(
                    "Adaptive parameter computation failed; using UI values",
                    exc_info=True,
                )

        # The second pair must not stop earlier than the first one.
        if use_s1 > use_s2:
            use_s2 = use_s1

        model = p.sd_model.model.diffusion_model
        max_inp = len(getattr(model, "input_blocks", [])) - 1
        if max_inp < 0:
            return

        # Start from a clean model: wrappers left by an interrupted run would
        # otherwise be reused with their old interpolation settings.
        KohyaHiresFix._unwrap_all(model)

        # UI depths are 1-based; clamp them to the available input blocks.
        d1_idx = max(0, min(int(use_d1) - 1, max_inp))
        d2_idx = max(0, min(int(use_d2) - 1, max_inp))
        scaler_mode = _safe_mode(scaler)

        # One entry per block index; if both pairs hit the same block the
        # later stop fraction wins.
        combined: Dict[int, float] = {}
        for s_stop, d_idx in ((float(use_s1), d1_idx), (float(use_s2), d2_idx)):
            combined[d_idx] = max(combined.get(d_idx, 0.0), s_stop)

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            if params.sampling_step < self.step_limit:
                return

            total = max(1, int(params.total_sampling_steps))

            for d_idx, s_stop in combined.items():
                # Output block paired with the input block: the same index with
                # "early out", otherwise the mirrored index from the end.
                out_idx = d_idx if early_out else -(d_idx + 1)
                try:
                    if params.sampling_step < total * s_stop:
                        if not isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = Scaler(
                                use_down, model.input_blocks[d_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )
                            model.output_blocks[out_idx] = Scaler(
                                use_up, model.output_blocks[out_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )

                        if smooth_scaling:
                            # Move the downscale linearly from use_down towards
                            # 1.0 as sampling approaches the stop step.
                            ratio = params.sampling_step / (total * s_stop)
                            ratio = float(max(0.0, min(1.0, ratio)))
                            cur_down = min((1.0 - use_down) * ratio + use_down, 1.0)
                            model.input_blocks[d_idx].scale = cur_down

                            if keep_unitary_product:
                                cur_up = 1.0 / max(1e-6, cur_down)
                            else:
                                cur_up = use_up * (use_down / max(1e-6, cur_down))
                            model.output_blocks[out_idx].scale = cur_up
                    else:
                        if isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = model.input_blocks[d_idx].block
                            model.output_blocks[out_idx] = model.output_blocks[out_idx].block
                except Exception:
                    # Never leave the model half-wrapped after a failure.
                    self._log.warning(
                        "Rescaling U-Net block %s failed; restoring original blocks",
                        d_idx, exc_info=True,
                    )
                    try:
                        KohyaHiresFix._unwrap_all(model)
                    except Exception:
                        pass

            self.step_limit = int(params.sampling_step) if only_one_pass else 0

        script_callbacks.on_cfg_denoiser(denoiser_callback)
        self._cb_registered = True

        # Record the effective (possibly adapted) values in the infotext.
        parameters = {
            "DSHF_res": resolution_choice, "DSHF_apply_res": apply_resolution,
            "DSHF_s1": use_s1, "DSHF_d1": use_d1, "DSHF_s2": use_s2, "DSHF_d2": use_d2,
            "DSHF_scaler": scaler_mode, "DSHF_down": use_down, "DSHF_up": use_up,
            "DSHF_smooth": smooth_scaling, "DSHF_early": early_out,
            "DSHF_one": only_one_pass, "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode_ui, "DSHF_recompute": recompute_scale_factor_mode_ui,
            "DSHF_adapt": adaptive_by_resolution, "DSHF_adapt_profile": adaptive_profile,
        }
        for k, v in parameters.items():
            p.extra_generation_params[k] = v

    def postprocess(self, p, processed, *args):
        """Restore the model, deregister the callback and save the settings."""
        try:
            KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
        finally:
            # Deregister the callback for real. Resetting the flag alone made
            # the next process() skip removal, so callbacks piled up.
            self._remove_callbacks()
            try:
                _atomic_save_yaml(CONFIG_PATH, OmegaConf.to_container(self.config, resolve=True) or {})
            except Exception:
                # Saving the last-used settings is best-effort.
                self._log.debug("Could not save config", exc_info=True)

    def process_batch(self, p, *args, **kwargs):
        """Re-arm the denoiser callback for every batch."""
        self.step_limit = 0
|
|
|