# dikdimon's picture
# Upload sd-webui-kohya-hiresfix-saveable using SD-Hub
# 9fd099c verified
# kohya_hires_fix_ru.py
# Версия: 1.5 (RU)
# Совместимость: A1111 / modules.scripts API, PyTorch >= 1.12, OmegaConf >= 2.2
# Новое в 1.5:
# - Переключатели align_corners (Авто/True/False) и recompute_scale_factor (Авто/True/False).
# - Сохранение/загрузка этих параметров в пресетах и конфиге.
# - По умолчанию False/False (как в 1.4) для стабильности и отсутствия предупреждений.
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from modules import scripts, script_callbacks
# Last-used settings live next to this script as "<stem>.yaml";
# named presets live in "<stem>.presets.yaml".
CONFIG_PATH = Path(__file__).with_suffix(".yaml")
PRESETS_PATH = Path(__file__).with_name(f"{Path(__file__).stem}.presets.yaml")

# ---- Predefined resolutions ----
# Group label -> list of (width, height) pairs offered in the UI dropdown.
RESOLUTION_GROUPS = {
    "Квадрат": [(1024, 1024)],
    "Портрет": [(640, 1536), (768, 1344), (832, 1216), (896, 1152)],
    "Альбом": [(1536, 640), (1344, 768), (1216, 832), (1152, 896)],
}
# The first entry is the "do not apply" sentinel; the rest read "<group>: <w>x<h>".
RESOLUTION_CHOICES: List[str] = ["— не применять —"] + [
    f"{group}: {w}x{h}"
    for group, dims in RESOLUTION_GROUPS.items()
    for w, h in dims
]
def parse_resolution_label(label: str) -> Optional[Tuple[int, int]]:
    """Extract ``(width, height)`` from a dropdown label such as ``"Портрет: 640x1536"``.

    Returns ``None`` for an empty label, for the "—" sentinel entry, and for
    any string that does not have the ``"<group>: <w>x<h>"`` shape.
    """
    if label and not label.startswith("—"):
        try:
            _group, dims = label.split(":")
            width_txt, height_txt = dims.strip().lower().split("x")
            return int(width_txt), int(height_txt)
        except Exception:
            pass
    return None
# ---- Вспомогательные утилиты ----
def _safe_mode(mode: str) -> str:
if mode == "nearest-exact":
return mode # при ошибке fallback в forward()
if mode in {"bicubic", "bilinear", "nearest"}:
return mode
return "bilinear"
def _load_yaml(path: Path, default: dict) -> dict:
    """Load a YAML mapping from *path*, falling back to *default*.

    Best-effort: the fallback is returned in these cases:

    * the file is missing, unreadable or invalid;
    * the document is empty;
    * the top level is not a mapping (e.g. a list).

    Callers can therefore always treat the result as a dict.
    """
    try:
        data = OmegaConf.to_container(OmegaConf.load(path), resolve=True)
    except Exception:
        return default
    # Callers wrap the result in DictConfig or iterate .items(); a top-level
    # list would make them crash, so anything that is not a non-empty dict
    # is replaced by the default.
    if not isinstance(data, dict) or not data:
        return default
    return data
def _atomic_save_yaml(path: Path, data: dict) -> None:
    """Write *data* to *path* as YAML via a temporary file and a rename.

    Best-effort: failures are reported on stdout instead of being raised.
    A partially written ``<path>.tmp`` file is removed on failure so it does
    not linger next to the real file.
    """
    tmp: Optional[Path] = None
    try:
        tmp = path.with_suffix(path.suffix + ".tmp")
        OmegaConf.save(DictConfig(data), tmp)
        # Writing to a sibling file first and then replacing means the target
        # is never left half-written if saving fails midway.
        tmp.replace(path)
    except Exception as exc:
        print(f"[Kohya Hires.fix] Failed to save {path}: {exc}")
        if tmp is not None:
            try:
                tmp.unlink()
            except OSError:
                pass
def _load_presets() -> Dict[str, dict]:
    """Return all saved presets as ``{name: settings_dict}``.

    Entries whose value is not a mapping (e.g. from hand-edited YAML) are
    skipped instead of raising, so a damaged presets file cannot break the UI.
    """
    data = _load_yaml(PRESETS_PATH, {})
    if not isinstance(data, dict):
        return {}
    return {
        str(name): dict(settings)
        for name, settings in data.items()
        if isinstance(settings, dict)
    }
def _save_presets(presets: Dict[str, dict]) -> None:
    """Persist the whole preset mapping to ``PRESETS_PATH`` (best-effort)."""
    _atomic_save_yaml(path=PRESETS_PATH, data=presets)
def _clamp(x: float, lo: float, hi: float) -> float:
return float(max(lo, min(hi, x)))
def _norm_mode_choice(value: str, default_: str = "false") -> str:
"""Привести выбор из UI к {'true','false','auto'}."""
s = str(value or "").strip().lower()
if s in ("true",):
return "true"
if s in ("false",):
return "false"
if s in ("авто", "auto"):
return "auto"
return default_
def _compute_adaptive_params(
    width: int,
    height: int,
    profile: str,
    base_s1: float,
    base_s2: float,
    base_d1: int,
    base_d2: int,
    base_down: float,
    base_up: float,
    keep_unitary_product: bool,
) -> Tuple[float, float, int, int, float, float]:
    """Adapt ``(s1, s2, d1, d2, downscale, upscale)`` to image size and aspect.

    Adjustments, applied in this order:

    * Megapixels relative to 1024x1024. At >= 1.5x or >= 1.1x the stop
      fractions grow and the downscale factor drops. At <= 0.8x the opposite
      happens.
    * Aspect ratio. At >= 1.6 both depths grow by one and the downscale drops
      by 0.05. At >= 2.0 the depths grow by one more and the stops by 0.02.
    * Profile. A "консер..." profile damps the stop shift by 0.6 and pulls
      the downscale halfway toward 0.5. An "агресс..." profile amplifies the
      stop shift by 1.3 and lowers the downscale by another 0.05.

    Results are clamped: stops to [0, 0.5], depths to [1, 10], downscale to
    [0.3, 0.9]. Upscale becomes ``1 / downscale`` when *keep_unitary_product*
    is set; otherwise it is *base_up* unchanged.
    """
    # Size relative to a 1024x1024 reference and long/short side ratio.
    area = max(1, int(width)) * max(1, int(height))
    rel_mpx = area / float(1024 * 1024)
    aspect = max(width, height) / float(max(1, min(width, height)))

    stop_shift = 0.0  # added to both stop fractions
    depth_shift = 0  # added to both block depths
    down = float(base_down)

    # 1) Resolution.
    if rel_mpx >= 1.5:
        stop_shift += 0.08
        down -= 0.10
    elif rel_mpx >= 1.1:
        stop_shift += 0.05
        down -= 0.05
    elif rel_mpx <= 0.8:
        stop_shift -= 0.02
        down += 0.05

    # 2) Aspect ratio: elongated frames wrap a deeper block.
    if aspect >= 1.6:
        depth_shift += 1
        down -= 0.05
        if aspect >= 2.0:
            depth_shift += 1
            stop_shift += 0.02

    # 3) Profile.
    profile_key = (profile or "Сбалансированный").strip().lower()
    if "консер" in profile_key:
        stop_shift *= 0.6
        down = 0.5 + 0.5 * (down - 0.5)  # halve the distance from 0.5
    elif "агресс" in profile_key:
        stop_shift *= 1.3
        down -= 0.05

    def _stop(base: float) -> float:
        return _clamp(base + stop_shift, 0.0, 0.5)

    def _depth(base: int) -> int:
        return max(1, min(10, int(base + depth_shift)))

    down = _clamp(down, 0.3, 0.9)
    up = 1.0 / max(1e-6, down) if keep_unitary_product else float(base_up)
    return _stop(base_s1), _stop(base_s2), _depth(base_d1), _depth(base_d2), down, up
# ---- Основные классы ----
class Scaler(torch.nn.Module):
    """Wrap a U-Net block: resample its input, then call the original block.

    ``scale`` is a plain attribute that the denoiser callback may change
    between steps, so ``forward`` reads it on every call.
    """

    def __init__(
        self,
        scale: float,
        block: torch.nn.Module,
        scaler: str,
        align_mode: str = "false",  # 'true' | 'false' | 'auto'
        recompute_mode: str = "false",  # 'true' | 'false' | 'auto'
    ) -> None:
        super().__init__()
        self.scale: float = float(scale)
        # Assigning a Module attribute registers it as a submodule.
        self.block: torch.nn.Module = block
        self.scaler: str = _safe_mode(scaler)
        self.align_mode: str = _norm_mode_choice(align_mode, "false")
        self.recompute_mode: str = _norm_mode_choice(recompute_mode, "false")

    def _interp_kwargs(self, mode: str) -> Dict[str, Any]:
        """Build ``F.interpolate`` keyword arguments for *mode*.

        ``align_corners`` is only added for bilinear/bicubic (PyTorch rejects
        it for nearest modes). ``recompute_scale_factor`` applies to every mode.
        An 'auto' setting omits the argument so PyTorch's default is used.
        """
        kw: Dict[str, Any] = {"scale_factor": self.scale, "mode": mode}
        align = {"true": True, "false": False}.get(self.align_mode)
        if mode in ("bilinear", "bicubic") and align is not None:
            kw["align_corners"] = align
        recompute = {"true": True, "false": False}.get(self.recompute_mode)
        if recompute is not None:
            kw["recompute_scale_factor"] = recompute
        return kw

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        mode = self.scaler
        try:
            x = F.interpolate(x, **self._interp_kwargs(mode))
        except Exception:
            # Fallback when the mode is not supported here (e.g.
            # "nearest-exact" on older PyTorch).
            safe = "nearest" if mode == "nearest-exact" else "bilinear"
            if safe == mode:
                # Retrying the identical call cannot succeed; surface the
                # original error instead of a duplicate one.
                raise
            x = F.interpolate(x, **self._interp_kwargs(safe))
        return self.block(x, *args, **kwargs)

    def extra_repr(self) -> str:
        return (
            f"scale={self.scale}, mode={self.scaler!r}, "
            f"align={self.align_mode}, recompute={self.recompute_mode}"
        )
class KohyaHiresFix(scripts.Script):
    """Dynamic hires.fix by temporarily rescaling inner U-Net features.

    While the sampling step is below ``s * total_steps``, the input block at
    depth ``d`` is wrapped in a downscaling ``Scaler`` and its paired output
    block in an upscaling one. Once the step passes that point, the original
    blocks are restored. Two (stop, depth) pairs are supported.
    """

    def __init__(self) -> None:
        super().__init__()
        # Last-used UI values: loaded from CONFIG_PATH, re-saved in postprocess().
        self.config: DictConfig = DictConfig(_load_yaml(CONFIG_PATH, {}))
        self.disable: bool = False
        # The denoiser callback ignores sampling steps below this value.
        self.step_limit: int = 0
        self.infotext_fields = []
        self._cb_registered: bool = False
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        self.p1: Tuple[float, int] = (0.15, 2)
        self.p2: Tuple[float, int] = (0.30, 3)

    def title(self) -> str:
        return "Kohya Hires.fix · Русская версия"

    def show(self, is_img2img: bool):
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the Gradio UI; the returned list matches process() argument order."""
        # Reset infotext bindings so a hot reload does not duplicate them.
        self.infotext_fields = []
        presets = _load_presets()
        cfg = self.config

        with gr.Accordion(label="Kohya Hires.fix", open=False):
            enable = gr.Checkbox(label="Включить расширение", value=False)

            # Predefined resolutions.
            with gr.Group():
                gr.Markdown("**Предустановленные разрешения**")
                with gr.Row():
                    resolution_choice = gr.Dropdown(
                        choices=RESOLUTION_CHOICES,
                        value=cfg.get("resolution_choice", RESOLUTION_CHOICES[0]),
                        label="Выбрать разрешение",
                    )
                    apply_resolution = gr.Checkbox(
                        label="Применять выбранное разрешение к ширине/высоте",
                        value=cfg.get("apply_resolution", False),
                    )

            # Scaling parameters.
            with gr.Group():
                gr.Markdown("**Параметры масштабирования**")
                with gr.Row():
                    s1 = gr.Slider(0.0, 0.5, step=0.01, label="Остановить на (доля шага) — Пара 1",
                                   value=cfg.get("s1", 0.15))
                    d1 = gr.Slider(1, 10, step=1, label="Глубина блока — Пара 1",
                                   value=cfg.get("d1", 3))
                with gr.Row():
                    s2 = gr.Slider(0.0, 0.5, step=0.01, label="Остановить на (доля шага) — Пара 2",
                                   value=cfg.get("s2", 0.30))
                    d2 = gr.Slider(1, 10, step=1, label="Глубина блока — Пара 2",
                                   value=cfg.get("d2", 4))
                with gr.Row():
                    scaler = gr.Dropdown(
                        choices=["bicubic", "bilinear", "nearest", "nearest-exact"],
                        label="Режим интерполяции слоя",
                        value=cfg.get("scaler", "bicubic"),
                    )
                    downscale = gr.Slider(0.1, 1.0, step=0.05, label="Коэффициент даунскейла (вход)",
                                          value=cfg.get("downscale", 0.5))
                    upscale = gr.Slider(1.0, 4.0, step=0.1, label="Коэффициент апскейла (выход)",
                                        value=cfg.get("upscale", 2.0))
                with gr.Row():
                    smooth_scaling = gr.Checkbox(label="Плавное изменение масштаба",
                                                 value=cfg.get("smooth_scaling", True))
                    keep_unitary_product = gr.Checkbox(
                        label="Сохранять суммарный масштаб = 1 при сглаживании",
                        value=cfg.get("keep_unitary_product", False),
                    )
                    early_out = gr.Checkbox(label="Ранний апскейл на прямом индексе выхода",
                                            value=cfg.get("early_out", False))
                    only_one_pass = gr.Checkbox(label="Только один проход (отключить на следующих шагах)",
                                                value=cfg.get("only_one_pass", True))

            # Advanced interpolation switches (align_corners / recompute_scale_factor).
            with gr.Group():
                gr.Markdown("**Интерполяция (продвинутое)**")
                with gr.Row():
                    align_corners_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=cfg.get("align_corners_mode", "False"),
                        label="align_corners режим",
                    )
                    recompute_scale_factor_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=cfg.get("recompute_scale_factor_mode", "False"),
                        label="recompute_scale_factor режим",
                    )

            # Resolution-based adaptation.
            with gr.Group():
                gr.Markdown("**Адаптация под разрешение**")
                with gr.Row():
                    adaptive_by_resolution = gr.Checkbox(
                        label="Адаптировать параметры под текущее разрешение",
                        value=cfg.get("adaptive_by_resolution", True),
                    )
                    adaptive_profile = gr.Dropdown(
                        choices=["Консервативный", "Сбалансированный", "Агрессивный"],
                        value=cfg.get("adaptive_profile", "Сбалансированный"),
                        label="Профиль адаптации",
                    )

            # Named presets.
            with gr.Group():
                gr.Markdown("**Именуемые пресеты**")
                with gr.Row():
                    preset_select = gr.Dropdown(
                        choices=sorted(presets),
                        value=None,
                        label="Выбрать пресет",
                    )
                    preset_name = gr.Textbox(
                        label="Имя пресета для сохранения/переопределения",
                        placeholder="например: xl-portrait-hires",
                        value="",
                    )
                with gr.Row():
                    btn_save = gr.Button("Сохранить как пресет", variant="primary")
                    btn_load = gr.Button("Загрузить пресет")
                    btn_delete = gr.Button("Удалить пресет", variant="stop")
                preset_status = gr.Markdown("")

            # Preset callbacks.
            def _save_preset_cb(
                name: str,
                d1_v: int, d2_v: int, s1_v: float, s2_v: float,
                scaler_v: str, down_v: float, up_v: float,
                smooth_v: bool, early_v: bool, one_v: bool, keep1_v: bool,
                align_v: str, recompute_v: str,
                res_choice_v: str, apply_res_v: bool,
                adapt_v: bool, adapt_prof_v: str,
            ):
                """Store the current UI values under *name*, overwriting any existing preset."""
                name = (name or "").strip()
                if not name:
                    return gr.update(), "⚠️ Укажите имя пресета."
                current = _load_presets()
                current[name] = {
                    "d1": int(d1_v), "d2": int(d2_v),
                    "s1": float(s1_v), "s2": float(s2_v),
                    "scaler": str(scaler_v),
                    "downscale": float(down_v),
                    "upscale": float(up_v),
                    "smooth_scaling": bool(smooth_v),
                    "early_out": bool(early_v),
                    "only_one_pass": bool(one_v),
                    "keep_unitary_product": bool(keep1_v),
                    "align_corners_mode": str(align_v),
                    "recompute_scale_factor_mode": str(recompute_v),
                    "resolution_choice": str(res_choice_v),
                    "apply_resolution": bool(apply_res_v),
                    "adaptive_by_resolution": bool(adapt_v),
                    "adaptive_profile": str(adapt_prof_v),
                }
                _save_presets(current)
                return gr.update(choices=sorted(current), value=name), f"✅ Сохранён пресет «{name}»."

            btn_save.click(
                _save_preset_cb,
                inputs=[
                    preset_name,
                    d1, d2, s1, s2,
                    scaler, downscale, upscale,
                    smooth_scaling, early_out, only_one_pass, keep_unitary_product,
                    align_corners_mode, recompute_scale_factor_mode,
                    resolution_choice, apply_resolution,
                    adaptive_by_resolution, adaptive_profile,
                ],
                outputs=[preset_select, preset_status],
            )

            def _load_preset_cb(selected: Optional[str]):
                """Return the selected preset's values in the order of btn_load's outputs."""
                name = (selected or "").strip()
                allp = _load_presets()
                if not name or name not in allp:
                    # Leave every field untouched, including the typed
                    # preset name, and only report the problem.
                    return (*(gr.update() for _ in range(18)), "⚠️ Пресет не выбран или не найден.")
                p = allp[name]
                return (
                    int(p.get("d1", 3)),
                    int(p.get("d2", 4)),
                    float(p.get("s1", 0.15)),
                    float(p.get("s2", 0.30)),
                    str(p.get("scaler", "bicubic")),
                    float(p.get("downscale", 0.5)),
                    float(p.get("upscale", 2.0)),
                    bool(p.get("smooth_scaling", True)),
                    bool(p.get("early_out", False)),
                    bool(p.get("only_one_pass", True)),
                    bool(p.get("keep_unitary_product", False)),
                    str(p.get("align_corners_mode", "False")),
                    str(p.get("recompute_scale_factor_mode", "False")),
                    str(p.get("resolution_choice", RESOLUTION_CHOICES[0])),
                    bool(p.get("apply_resolution", False)),
                    bool(p.get("adaptive_by_resolution", True)),
                    str(p.get("adaptive_profile", "Сбалансированный")),
                    gr.update(value=name),
                    f"✅ Загружен пресет «{name}».",
                )

            btn_load.click(
                _load_preset_cb,
                inputs=[preset_select],
                outputs=[
                    d1, d2, s1, s2,
                    scaler, downscale, upscale,
                    smooth_scaling, early_out, only_one_pass, keep_unitary_product,
                    align_corners_mode, recompute_scale_factor_mode,
                    resolution_choice, apply_resolution,
                    adaptive_by_resolution, adaptive_profile,
                    preset_name, preset_status,
                ],
            )

            def _delete_preset_cb(selected: Optional[str]):
                """Remove the selected preset from the presets file."""
                name = (selected or "").strip()
                current = _load_presets()
                if not name or name not in current:
                    return gr.update(), "⚠️ Пресет не выбран или не найден."
                current.pop(name, None)
                _save_presets(current)
                return gr.update(choices=sorted(current), value=None), f"🗑️ Удалён пресет «{name}»."

            btn_delete.click(
                _delete_preset_cb,
                inputs=[preset_select],
                outputs=[preset_select, preset_status],
            )

        # Infotext bindings. The extension counts as enabled when DSHF_s1 is present.
        self.infotext_fields.append((enable, lambda d: d.get("DSHF_s1", False)))
        for k, element in {
            "DSHF_res": resolution_choice, "DSHF_apply_res": apply_resolution,
            "DSHF_s1": s1, "DSHF_d1": d1, "DSHF_s2": s2, "DSHF_d2": d2,
            "DSHF_scaler": scaler, "DSHF_down": downscale, "DSHF_up": upscale,
            "DSHF_smooth": smooth_scaling, "DSHF_early": early_out,
            "DSHF_one": only_one_pass, "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode, "DSHF_recompute": recompute_scale_factor_mode,
            "DSHF_adapt": adaptive_by_resolution, "DSHF_adapt_profile": adaptive_profile,
        }.items():
            self.infotext_fields.append((element, k))

        # Order must match the parameters of process(...).
        return [
            enable,
            only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale,
            smooth_scaling, early_out, keep_unitary_product,
            align_corners_mode, recompute_scale_factor_mode,
            resolution_choice, apply_resolution,
            adaptive_by_resolution, adaptive_profile,
            # Preset widgets (not used by process).
            preset_select, preset_name,
        ]

    def _remove_callbacks(self) -> None:
        """Unregister every webui callback registered from this script file (best-effort)."""
        try:
            script_callbacks.remove_current_script_callbacks()
        except Exception:
            pass
        self._cb_registered = False

    @staticmethod
    def _unwrap_all(model) -> None:
        """Replace every ``Scaler`` in input/output blocks with the block it wraps.

        Nested wrappers are peeled completely, so a block wrapped twice is
        also fully restored.
        """
        if model is None:
            return
        for attr in ("input_blocks", "output_blocks"):
            blocks = getattr(model, attr, None)
            if blocks is None:
                continue
            for i, b in enumerate(blocks):
                inner = b
                while isinstance(inner, Scaler):
                    inner = inner.block
                if inner is not b:
                    blocks[i] = inner

    def process(
        self,
        p,
        enable: bool,
        only_one_pass: bool,
        d1: int,
        d2: int,
        s1: float,
        s2: float,
        scaler: str,
        downscale: float,
        upscale: float,
        smooth_scaling: bool,
        early_out: bool,
        keep_unitary_product: bool,
        align_corners_mode_ui: str,
        recompute_scale_factor_mode_ui: str,
        resolution_choice: str,
        apply_resolution: bool,
        adaptive_by_resolution: bool,
        adaptive_profile: str,
        selected_preset: Optional[str],
        new_preset_name: str,
    ):
        """Configure U-Net block wrapping for this generation.

        Arguments after ``p`` arrive in the order returned by ``ui()``.
        ``selected_preset`` and ``new_preset_name`` are UI-only and ignored here.
        """
        align_mode = _norm_mode_choice(align_corners_mode_ui, "false")
        recompute_mode = _norm_mode_choice(recompute_scale_factor_mode_ui, "false")

        # Remember the raw UI values; postprocess() writes them to CONFIG_PATH.
        self.config = DictConfig({
            "s1": s1, "s2": s2, "d1": d1, "d2": d2,
            "scaler": scaler, "downscale": downscale, "upscale": upscale,
            "smooth_scaling": smooth_scaling, "early_out": early_out, "only_one_pass": only_one_pass,
            "keep_unitary_product": keep_unitary_product,
            "align_corners_mode": align_corners_mode_ui,
            "recompute_scale_factor_mode": recompute_scale_factor_mode_ui,
            "resolution_choice": resolution_choice, "apply_resolution": apply_resolution,
            "adaptive_by_resolution": adaptive_by_resolution, "adaptive_profile": adaptive_profile,
        })
        self.step_limit = 0

        # Optionally force the chosen preset resolution onto the job.
        if apply_resolution:
            wh = parse_resolution_label(resolution_choice)
            if wh:
                p.width, p.height = wh

        # Always drop whatever callback an earlier run left registered.
        # Otherwise a new callback would be added on every generation, each
        # one closing over stale settings.
        self._remove_callbacks()

        if not enable or self.disable:
            try:
                KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
            except Exception:
                pass
            return

        # Effective values: the UI values, optionally adapted to the resolution.
        use_s1, use_s2 = s1, s2
        use_d1, use_d2 = d1, d2
        use_down, use_up = downscale, upscale
        if adaptive_by_resolution:
            try:
                use_s1, use_s2, use_d1, use_d2, use_down, use_up = _compute_adaptive_params(
                    int(p.width), int(p.height),
                    adaptive_profile,
                    s1, s2, d1, d2,
                    downscale, upscale,
                    keep_unitary_product,
                )
            except Exception:
                pass  # best-effort: keep the UI values
        if use_s1 > use_s2:
            use_s2 = use_s1

        model = p.sd_model.model.diffusion_model
        max_inp = len(getattr(model, "input_blocks", [])) - 1
        if max_inp < 0:
            return
        # Start from a clean model in case an aborted run left wrappers behind.
        try:
            KohyaHiresFix._unwrap_all(model)
        except Exception:
            pass

        # UI depths are 1-based; clamp to valid input-block indices.
        d1_idx = max(0, min(int(use_d1) - 1, max_inp))
        d2_idx = max(0, min(int(use_d2) - 1, max_inp))
        scaler_mode = _safe_mode(scaler)

        # Merge pairs targeting the same depth, keeping the later stop fraction.
        combined: Dict[int, float] = {}
        for s_stop, d_idx in ((float(use_s1), d1_idx), (float(use_s2), d2_idx)):
            combined[d_idx] = max(combined.get(d_idx, 0.0), s_stop)

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            # With only_one_pass, step_limit is set to the last processed step.
            # A later sampling pass whose step counter restarts lower is
            # therefore skipped until it catches up.
            if params.sampling_step < self.step_limit:
                return
            total = max(1, int(params.total_sampling_steps))
            for d_idx, s_stop in combined.items():
                # Paired output block: mirrored index, or the same index when early_out.
                out_idx = d_idx if early_out else -(d_idx + 1)
                try:
                    if params.sampling_step < total * s_stop:
                        if not isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = Scaler(
                                use_down, model.input_blocks[d_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )
                        if not isinstance(model.output_blocks[out_idx], Scaler):
                            model.output_blocks[out_idx] = Scaler(
                                use_up, model.output_blocks[out_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )
                        if smooth_scaling:
                            # Ease the downscale from use_down toward 1.0 as
                            # the step approaches the stop point.
                            ratio = _clamp(params.sampling_step / (total * s_stop), 0.0, 1.0)
                            cur_down = min((1.0 - use_down) * ratio + use_down, 1.0)
                            model.input_blocks[d_idx].scale = cur_down
                            if keep_unitary_product:
                                cur_up = 1.0 / max(1e-6, cur_down)
                            else:
                                cur_up = use_up * (use_down / max(1e-6, cur_down))
                            model.output_blocks[out_idx].scale = cur_up
                    else:
                        if isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = model.input_blocks[d_idx].block
                        if isinstance(model.output_blocks[out_idx], Scaler):
                            model.output_blocks[out_idx] = model.output_blocks[out_idx].block
                except Exception:
                    # Never leave the model half-wrapped after an error.
                    try:
                        KohyaHiresFix._unwrap_all(model)
                    except Exception:
                        pass
            self.step_limit = int(params.sampling_step) if only_one_pass else 0

        script_callbacks.on_cfg_denoiser(denoiser_callback)
        self._cb_registered = True

        # Infotext records the UI values so that pasting them back (with the
        # same adaptive flag and size) reproduces the run. Storing adapted
        # values here would make them get adapted a second time.
        parameters = {
            "DSHF_res": resolution_choice, "DSHF_apply_res": apply_resolution,
            "DSHF_s1": s1, "DSHF_d1": d1, "DSHF_s2": s2, "DSHF_d2": d2,
            "DSHF_scaler": scaler_mode, "DSHF_down": downscale, "DSHF_up": upscale,
            "DSHF_smooth": smooth_scaling, "DSHF_early": early_out,
            "DSHF_one": only_one_pass, "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode_ui, "DSHF_recompute": recompute_scale_factor_mode_ui,
            "DSHF_adapt": adaptive_by_resolution, "DSHF_adapt_profile": adaptive_profile,
        }
        if adaptive_by_resolution:
            # Informational only (no UI binding): values actually used.
            parameters["DSHF_effective"] = (
                f"s1={float(use_s1):.3f} d1={int(use_d1)} "
                f"s2={float(use_s2):.3f} d2={int(use_d2)} "
                f"down={float(use_down):.3f} up={float(use_up):.3f}"
            )
        for k, v in parameters.items():
            p.extra_generation_params[k] = v

    def postprocess(self, p, processed, *args):
        """Restore original blocks, unregister the callback, and save the config."""
        try:
            KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
        finally:
            # on_cfg_denoiser callbacks are global: leaving ours registered
            # would let it fire, with stale settings, during later generations.
            self._remove_callbacks()
            try:
                _atomic_save_yaml(CONFIG_PATH, OmegaConf.to_container(self.config, resolve=True) or {})
            except Exception:
                pass

    def process_batch(self, p, *args, **kwargs):
        # Each batch restarts sampling at step 0, so re-arm the callback.
        self.step_limit = 0