# kohya_hires_fix_ru.py
# Version: 1.6 (RU)
# Compatibility: A1111 / modules.scripts API, PyTorch >= 1.12, OmegaConf >= 2.2
# New in 1.6:
# - Safe mapping of U-Net input/output block indices (patch #1) + guard when no pair exists.
# - Fail-safe on exceptions inside the callback: remove the callback, unwrap,
#   self.disable=True (patch #2).
# - Optional Smoothstep smoothing curve (patch #3) + key in presets and config.
# - Lightweight diagnostics on the first iteration (patch #4).
#
# New in 1.5:
# - align_corners (Auto/True/False) and recompute_scale_factor (Auto/True/False) switches.
# - These parameters are saved to / loaded from presets and the config.
# - Defaults are False/False (as in 1.4) for stability and to avoid warnings.

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

from modules import scripts, script_callbacks

CONFIG_PATH = Path(__file__).with_suffix(".yaml")
PRESETS_PATH = Path(__file__).with_name(Path(__file__).stem + ".presets.yaml")

# ---- Predefined resolutions ----
RESOLUTION_GROUPS = {
    "Квадрат": [(1024, 1024)],
    "Портрет": [(640, 1536), (768, 1344), (832, 1216), (896, 1152)],
    "Альбом": [(1536, 640), (1344, 768), (1216, 832), (1152, 896)],
}

# Dropdown labels of the form "<group>: <W>x<H>"; the first entry means "do not apply".
RESOLUTION_CHOICES: List[str] = ["— не применять —"]
for group, dims in RESOLUTION_GROUPS.items():
    for w, h in dims:
        RESOLUTION_CHOICES.append(f"{group}: {w}x{h}")


def parse_resolution_label(label: str) -> Optional[Tuple[int, int]]:
    """Parse a ``"<group>: <W>x<H>"`` dropdown label into ``(width, height)``.

    Returns ``None`` for an empty/non-string label, for the "do not apply"
    entry (which starts with an em dash) and for anything that fails to parse.
    """
    if not isinstance(label, str) or not label or label.startswith("—"):
        return None
    try:
        _, wh = label.split(":")
        w, h = wh.strip().lower().split("x")
        return int(w), int(h)
    except ValueError:
        # Wrong number of ":"/"x" separators or non-integer dimensions.
        return None


# ---- Helper utilities ----
def _safe_mode(mode: str) -> str:
    """Return ``mode`` if it is a supported interpolation mode, else ``"bilinear"``."""
    if mode == "nearest-exact":
        # Passed through as-is; Scaler.forward() falls back if interpolation fails.
        return mode
    if mode in {"bicubic", "bilinear", "nearest"}:
        return mode
    return "bilinear"


def _load_yaml(path: Path, default: dict) -> dict:
    """Load ``path`` as a plain dict; return ``default`` on any failure.

    ``default`` is also returned when the file is empty or its root is not a
    mapping, so callers can always treat the result as a dict.
    """
    try:
        loaded = OmegaConf.to_container(OmegaConf.load(path), resolve=True)
    except Exception:
        # Best effort: a missing or unreadable file simply means "no saved data".
        return default
    if isinstance(loaded, dict) and loaded:
        return loaded
    return default


def _atomic_save_yaml(path: Path, data: dict) -> None:
    """Write ``data`` to ``path`` via a sibling ``.tmp`` file and a rename.

    Never raises: saving is best-effort, but a failure is reported on stdout
    instead of being silently dropped.
    """
    try:
        tmp = path.with_suffix(path.suffix + ".tmp")
        OmegaConf.save(DictConfig(data), tmp)
        tmp.replace(path)
    except Exception as exc:
        print(f"[KohyaHiresFix] Could not save {path}: {type(exc).__name__}: {exc}")


def _load_presets() -> Dict[str, dict]:
    """Load named presets from ``PRESETS_PATH`` as ``{name: settings}``.

    Entries whose value is not a mapping are skipped so that a hand-edited or
    corrupted presets file cannot break UI construction.
    """
    data = _load_yaml(PRESETS_PATH, {})
    if not isinstance(data, dict):
        return {}
    return {str(k): dict(v) for k, v in data.items() if isinstance(v, dict)}


def _save_presets(presets: Dict[str, dict]) -> None:
    """Persist all named presets to ``PRESETS_PATH``."""
    _atomic_save_yaml(PRESETS_PATH, presets)


def _clamp(x: float, lo: float, hi: float) -> float:
    """Clamp ``x`` into ``[lo, hi]`` and return it as a float."""
    return float(max(lo, min(hi, x)))


def _norm_mode_choice(value: str, default_: str = "false") -> str:
    """Normalize a UI choice to one of ``'true'``, ``'false'``, ``'auto'``.

    Unknown or empty values map to ``default_``.
    """
    s = str(value or "").strip().lower()
    if s in ("true",):
        return "true"
    if s in ("false",):
        return "false"
    if s in ("авто", "auto"):
        return "auto"
    return default_


def _compute_adaptive_params(
    width: int,
    height: int,
    profile: str,
    base_s1: float,
    base_s2: float,
    base_d1: int,
    base_d2: int,
    base_down: float,
    base_up: float,
    keep_unitary_product: bool,
) -> Tuple[float, float, int, int, float, float]:
    """Adapt ``(s1, s2, d1, d2, downscale, upscale)`` to megapixels and aspect ratio.

    The megapixel count is measured relative to 1024x1024. Results are clamped:
    stop fractions to ``[0, 0.5]``, depths to ``[1, 10]``, downscale to
    ``[0.3, 0.9]``. When ``keep_unitary_product`` is set the upscale is the
    reciprocal of the adapted downscale, otherwise ``base_up`` is returned.
    """
    rel_mpx = (max(1, int(width)) * max(1, int(height))) / float(1024 * 1024)
    aspect = max(width, height) / float(max(1, min(width, height)))

    s_add = 0.0
    d_add = 0
    down = float(base_down)

    # Megapixels
    if rel_mpx >= 1.5:
        s_add += 0.08
        down -= 0.10
    elif rel_mpx >= 1.1:
        s_add += 0.05
        down -= 0.05
    elif rel_mpx <= 0.8:
        s_add -= 0.02
        down += 0.05

    # Aspect ratio
    if aspect >= 1.6:
        d_add += 1
        down -= 0.05
    if aspect >= 2.0:
        d_add += 1
        s_add += 0.02

    # Profile (matched by substring of the localized profile name)
    prof = (profile or "Сбалансированный").strip().lower()
    if "консер" in prof:
        s_add *= 0.6
        down = 0.5 + 0.5 * (down - 0.5)  # pull the downscale halfway towards 0.5
    elif "агресс" in prof:
        s_add *= 1.3
        down -= 0.05

    s1 = _clamp(base_s1 + s_add, 0.0, 0.5)
    s2 = _clamp(base_s2 + s_add, 0.0, 0.5)
    d1 = max(1, min(10, int(base_d1 + d_add)))
    d2 = max(1, min(10, int(base_d2 + d_add)))
    down = _clamp(down, 0.3, 0.9)

    if keep_unitary_product:
        up = 1.0 / max(1e-6, down)
    else:
        up = float(base_up)

    return s1, s2, d1, d2, down, up


# ---- Main classes ----
class Scaler(torch.nn.Module):
    """U-Net block wrapper: rescale the input, then call the wrapped module."""

    def __init__(
        self,
        scale: float,
        block: torch.nn.Module,
        scaler: str,
        align_mode: str = "false",  # 'true' | 'false' | 'auto'
        recompute_mode: str = "false",  # 'true' | 'false' | 'auto'
    ) -> None:
        super().__init__()
        self.scale: float = float(scale)
        self.block: torch.nn.Module = block
        self.scaler: str = _safe_mode(scaler)
        self.align_mode: str = _norm_mode_choice(align_mode, "false")
        self.recompute_mode: str = _norm_mode_choice(recompute_mode, "false")

    def _interp_kwargs(self, mode: str) -> Dict[str, Any]:
        """Build the keyword arguments for ``F.interpolate`` for ``mode``."""
        kw: Dict[str, Any] = dict(scale_factor=self.scale, mode=mode)
        # align_corners is only passed for the linear modes.
        if mode in ("bilinear", "bicubic"):
            if self.align_mode == "true":
                kw["align_corners"] = True
            elif self.align_mode == "false":
                kw["align_corners"] = False
            # 'auto' -> the argument is not passed at all
        # recompute_scale_factor is passed for every mode.
        if self.recompute_mode == "true":
            kw["recompute_scale_factor"] = True
        elif self.recompute_mode == "false":
            kw["recompute_scale_factor"] = False
        # 'auto' -> the argument is not passed at all
        return kw

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Interpolate ``x`` by ``self.scale`` and forward it to the wrapped block.

        If interpolation with the configured mode raises, it is retried once
        with ``"nearest"`` (for ``"nearest-exact"``) or ``"bilinear"``.
        """
        mode = self.scaler
        try:
            x = F.interpolate(x, **self._interp_kwargs(mode))
        except Exception:
            # Fallback for a mode the running PyTorch build does not accept.
            safe = "nearest" if mode == "nearest-exact" else "bilinear"
            x = F.interpolate(x, **self._interp_kwargs(safe))
        return self.block(x, *args, **kwargs)


class KohyaHiresFix(scripts.Script):
    """Dynamic hires.fix by temporarily rescaling internal U-Net features."""

    def __init__(self) -> None:
        super().__init__()
        self.config: DictConfig = DictConfig(_load_yaml(CONFIG_PATH, {}))
        self.disable: bool = False
        self.step_limit: int = 0
        self.infotext_fields = []
        self._cb_registered: bool = False

    def title(self) -> str:
        """Return the script title."""
        return "Kohya Hires.fix · Русская версия"

    def show(self, is_img2img: bool):
        """Return ``scripts.AlwaysVisible``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the Gradio UI and return the components passed to ``process``.

        The returned list order must match the positional parameters of
        ``process`` (the two trailing preset widgets are accepted but unused).

        NOTE(review): the widget nesting was reconstructed from a
        whitespace-collapsed source; confirm that ``early_out`` and
        ``only_one_pass`` belong in the same row as the smoothing widgets.
        """
        # Reset infotext fields on hot reload.
        self.infotext_fields = []

        presets = _load_presets()

        with gr.Accordion(label="Kohya Hires.fix", open=False):
            enable = gr.Checkbox(label="Включить расширение", value=False)

            # Resolutions
            with gr.Group():
                gr.Markdown("**Предустановленные разрешения**")
                with gr.Row():
                    resolution_choice = gr.Dropdown(
                        choices=RESOLUTION_CHOICES,
                        value=self.config.get("resolution_choice", RESOLUTION_CHOICES[0]),
                        label="Выбрать разрешение",
                    )
                    apply_resolution = gr.Checkbox(
                        label="Применять выбранное разрешение к ширине/высоте",
                        value=self.config.get("apply_resolution", False),
                    )

            # Scaling parameters
            with gr.Group():
                gr.Markdown("**Параметры масштабирования**")
                with gr.Row():
                    s1 = gr.Slider(
                        0.0, 0.5, step=0.01,
                        label="Остановить на (доля шага) — Пара 1",
                        value=self.config.get("s1", 0.15),
                    )
                    d1 = gr.Slider(
                        1, 10, step=1,
                        label="Глубина блока — Пара 1",
                        value=self.config.get("d1", 3),
                    )
                with gr.Row():
                    s2 = gr.Slider(
                        0.0, 0.5, step=0.01,
                        label="Остановить на (доля шага) — Пара 2",
                        value=self.config.get("s2", 0.30),
                    )
                    d2 = gr.Slider(
                        1, 10, step=1,
                        label="Глубина блока — Пара 2",
                        value=self.config.get("d2", 4),
                    )
                with gr.Row():
                    scaler = gr.Dropdown(
                        choices=["bicubic", "bilinear", "nearest", "nearest-exact"],
                        label="Режим интерполяции слоя",
                        value=self.config.get("scaler", "bicubic"),
                    )
                    downscale = gr.Slider(
                        0.1, 1.0, step=0.05,
                        label="Коэффициент даунскейла (вход)",
                        value=self.config.get("downscale", 0.5),
                    )
                    upscale = gr.Slider(
                        1.0, 4.0, step=0.1,
                        label="Коэффициент апскейла (выход)",
                        value=self.config.get("upscale", 2.0),
                    )
                with gr.Row():
                    smooth_scaling = gr.Checkbox(
                        label="Плавное изменение масштаба",
                        value=self.config.get("smooth_scaling", True),
                    )
                    smoothing_curve = gr.Dropdown(
                        choices=["Линейная", "Smoothstep"],
                        value=self.config.get("smoothing_curve", "Линейная"),
                        label="Кривая сглаживания",
                    )
                    keep_unitary_product = gr.Checkbox(
                        label="Сохранять суммарный масштаб = 1 при сглаживании",
                        value=self.config.get("keep_unitary_product", False),
                    )
                    early_out = gr.Checkbox(
                        label="Ранний апскейл на прямом индексе выхода",
                        value=self.config.get("early_out", False),
                    )
                    only_one_pass = gr.Checkbox(
                        label="Только один проход (отключить на следующих шагах)",
                        value=self.config.get("only_one_pass", True),
                    )

            # Interpolation switches
            with gr.Group():
                gr.Markdown("**Интерполяция (продвинутое)**")
                with gr.Row():
                    align_corners_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=self.config.get("align_corners_mode", "False"),
                        label="align_corners режим",
                    )
                    recompute_scale_factor_mode = gr.Dropdown(
                        choices=["False", "True", "Авто"],
                        value=self.config.get("recompute_scale_factor_mode", "False"),
                        label="recompute_scale_factor режим",
                    )

            # Adaptation
            with gr.Group():
                gr.Markdown("**Адаптация под разрешение**")
                with gr.Row():
                    adaptive_by_resolution = gr.Checkbox(
                        label="Адаптировать параметры под текущее разрешение",
                        value=self.config.get("adaptive_by_resolution", True),
                    )
                    adaptive_profile = gr.Dropdown(
                        choices=["Консервативный", "Сбалансированный", "Агрессивный"],
                        value=self.config.get("adaptive_profile", "Сбалансированный"),
                        label="Профиль адаптации",
                    )

            # Presets
            with gr.Group():
                gr.Markdown("**Именуемые пресеты**")
                with gr.Row():
                    preset_select = gr.Dropdown(
                        choices=sorted(presets),
                        value=None,
                        label="Выбрать пресет",
                    )
                    preset_name = gr.Textbox(
                        label="Имя пресета для сохранения/переопределения",
                        placeholder="например: xl-portrait-hires",
                        value="",
                    )
                with gr.Row():
                    btn_save = gr.Button("Сохранить как пресет", variant="primary")
                    btn_load = gr.Button("Загрузить пресет")
                    btn_delete = gr.Button("Удалить пресет", variant="stop")
                preset_status = gr.Markdown("")

            # Components stored in a preset. The order is shared by the save
            # inputs, the load outputs and the tuples returned by the load
            # callback, so the three can never drift apart.
            preset_components = [
                d1, d2, s1, s2, scaler, downscale, upscale,
                smooth_scaling, smoothing_curve, early_out, only_one_pass,
                keep_unitary_product,
                align_corners_mode, recompute_scale_factor_mode,
                resolution_choice, apply_resolution,
                adaptive_by_resolution, adaptive_profile,
            ]

            # Preset callbacks
            def _save_preset_cb(
                name: str,
                d1_v: int, d2_v: int, s1_v: float, s2_v: float,
                scaler_v: str, down_v: float, up_v: float,
                smooth_v: bool, smooth_curve_v: str, early_v: bool, one_v: bool,
                keep1_v: bool,
                align_v: str, recompute_v: str,
                res_choice_v: str, apply_res_v: bool,
                adapt_v: bool, adapt_prof_v: str,
            ):
                """Save the current widget values under ``name`` (overwriting)."""
                name = (name or "").strip()
                if not name:
                    return gr.update(), "⚠️ Укажите имя пресета."
                current = _load_presets()
                current[name] = {
                    "d1": int(d1_v),
                    "d2": int(d2_v),
                    "s1": float(s1_v),
                    "s2": float(s2_v),
                    "scaler": str(scaler_v),
                    "downscale": float(down_v),
                    "upscale": float(up_v),
                    "smooth_scaling": bool(smooth_v),
                    "smoothing_curve": str(smooth_curve_v),
                    "early_out": bool(early_v),
                    "only_one_pass": bool(one_v),
                    "keep_unitary_product": bool(keep1_v),
                    "align_corners_mode": str(align_v),
                    "recompute_scale_factor_mode": str(recompute_v),
                    "resolution_choice": str(res_choice_v),
                    "apply_resolution": bool(apply_res_v),
                    "adaptive_by_resolution": bool(adapt_v),
                    "adaptive_profile": str(adapt_prof_v),
                }
                _save_presets(current)
                return (
                    gr.update(choices=sorted(current), value=name),
                    f"✅ Сохранено пресет «{name}».",
                )

            btn_save.click(
                _save_preset_cb,
                inputs=[preset_name, *preset_components],
                outputs=[preset_select, preset_status],
            )

            def _load_preset_cb(selected: Optional[str]):
                """Return widget values for the selected preset.

                Always returns one value per declared output
                (``preset_components`` + preset name + status).
                """
                name = (selected or "").strip()
                allp = _load_presets()
                if not name or name not in allp:
                    # One no-op update per preset component: leave widgets untouched.
                    return (
                        *(gr.update() for _ in preset_components),
                        gr.update(value=name),
                        "⚠️ Пресет не выбран или не найден.",
                    )
                p = allp[name]
                return (
                    int(p.get("d1", 3)),
                    int(p.get("d2", 4)),
                    float(p.get("s1", 0.15)),
                    float(p.get("s2", 0.30)),
                    str(p.get("scaler", "bicubic")),
                    float(p.get("downscale", 0.5)),
                    float(p.get("upscale", 2.0)),
                    bool(p.get("smooth_scaling", True)),
                    str(p.get("smoothing_curve", "Линейная")),
                    bool(p.get("early_out", False)),
                    bool(p.get("only_one_pass", True)),
                    bool(p.get("keep_unitary_product", False)),
                    str(p.get("align_corners_mode", "False")),
                    str(p.get("recompute_scale_factor_mode", "False")),
                    str(p.get("resolution_choice", RESOLUTION_CHOICES[0])),
                    bool(p.get("apply_resolution", False)),
                    bool(p.get("adaptive_by_resolution", True)),
                    str(p.get("adaptive_profile", "Сбалансированный")),
                    gr.update(value=name),
                    f"✅ Загружен пресет «{name}».",
                )

            btn_load.click(
                _load_preset_cb,
                inputs=[preset_select],
                outputs=[*preset_components, preset_name, preset_status],
            )

            def _delete_preset_cb(selected: Optional[str]):
                """Delete the selected preset and refresh the dropdown."""
                name = (selected or "").strip()
                current = _load_presets()
                if not name or name not in current:
                    return gr.update(), "⚠️ Пресет не выбран или не найден."
                current.pop(name, None)
                _save_presets(current)
                return (
                    gr.update(choices=sorted(current), value=None),
                    f"🗑️ Удалён пресет «{name}».",
                )

            btn_delete.click(
                _delete_preset_cb,
                inputs=[preset_select],
                outputs=[preset_select, preset_status],
            )

        # Infotext fields: the extension counts as enabled when the pasted
        # parameters contain its keys at all.
        self.infotext_fields.append((enable, lambda d: "DSHF_s1" in d))
        for k, element in {
            "DSHF_res": resolution_choice,
            "DSHF_apply_res": apply_resolution,
            "DSHF_s1": s1,
            "DSHF_d1": d1,
            "DSHF_s2": s2,
            "DSHF_d2": d2,
            "DSHF_scaler": scaler,
            "DSHF_down": downscale,
            "DSHF_up": upscale,
            "DSHF_smooth": smooth_scaling,
            "DSHF_smooth_curve": smoothing_curve,
            "DSHF_early": early_out,
            "DSHF_one": only_one_pass,
            "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode,
            "DSHF_recompute": recompute_scale_factor_mode,
            "DSHF_adapt": adaptive_by_resolution,
            "DSHF_adapt_profile": adaptive_profile,
        }.items():
            self.infotext_fields.append((element, k))

        # The order must match process(...).
        return [
            enable, only_one_pass, d1, d2, s1, s2, scaler, downscale, upscale,
            smooth_scaling, smoothing_curve, early_out, keep_unitary_product,
            align_corners_mode, recompute_scale_factor_mode,
            resolution_choice, apply_resolution,
            adaptive_by_resolution, adaptive_profile,
            # presets (not used by process)
            preset_select, preset_name,
        ]

    @staticmethod
    def _unwrap_all(model) -> None:
        """Replace every ``Scaler`` in the model's block lists by its wrapped block.

        Nested ``Scaler`` layers are unwrapped completely, so no wrapper can
        survive a cleanup even if a block was wrapped more than once.
        """
        if model is None:
            return
        for attr in ("input_blocks", "output_blocks"):
            blocks = getattr(model, attr, None)
            if blocks is None:
                continue
            # Iterate over a snapshot because entries are reassigned in place.
            for i, b in enumerate(list(blocks)):
                if not isinstance(b, Scaler):
                    continue
                while isinstance(b, Scaler):
                    b = b.block
                blocks[i] = b

    @staticmethod
    def _map_output_index(model, in_idx: int, early_out: bool) -> Optional[int]:
        """Safely map an input-block index to an output-block index.

        - ``early_out=True``: use the same index, clamped to the length of
          ``output_blocks``.
        - ``early_out=False``: mirror the index relative to the end of
          ``output_blocks``.

        Returns ``None`` when the model has no (or empty) ``output_blocks``.
        """
        outs = getattr(model, "output_blocks", None)
        if not outs:
            return None
        n_out = len(outs)
        if n_out == 0:
            return None
        if early_out:
            # Direct mapping: clamp into [0, n_out - 1].
            return max(0, min(int(in_idx), n_out - 1))
        # Mirrored mapping: input index 0 <-> last output block.
        mirror = (n_out - 1) - int(in_idx)
        return max(0, min(mirror, n_out - 1))

    def process(
        self,
        p,
        enable: bool,
        only_one_pass: bool,
        d1: int,
        d2: int,
        s1: float,
        s2: float,
        scaler: str,
        downscale: float,
        upscale: float,
        smooth_scaling: bool,
        smoothing_curve: str,
        early_out: bool,
        keep_unitary_product: bool,
        align_corners_mode_ui: str,
        recompute_scale_factor_mode_ui: str,
        resolution_choice: str,
        apply_resolution: bool,
        adaptive_by_resolution: bool,
        adaptive_profile: str,
        selected_preset: Optional[str],
        new_preset_name: str,
    ):
        """Record settings and register the denoiser callback for this run.

        Positional parameters after ``p`` follow the list returned by ``ui``.
        ``selected_preset`` and ``new_preset_name`` are accepted but unused.
        """
        # Normalize the interpolation modes coming from the UI.
        align_mode = _norm_mode_choice(align_corners_mode_ui, "false")
        recompute_mode = _norm_mode_choice(recompute_scale_factor_mode_ui, "false")

        # Remember the latest values; postprocess() writes them to CONFIG_PATH.
        self.config = DictConfig({
            "s1": s1, "s2": s2, "d1": d1, "d2": d2,
            "scaler": scaler, "downscale": downscale, "upscale": upscale,
            "smooth_scaling": smooth_scaling,
            "smoothing_curve": smoothing_curve,
            "early_out": early_out,
            "only_one_pass": only_one_pass,
            "keep_unitary_product": keep_unitary_product,
            "align_corners_mode": align_corners_mode_ui,
            "recompute_scale_factor_mode": recompute_scale_factor_mode_ui,
            "resolution_choice": resolution_choice,
            "apply_resolution": apply_resolution,
            "adaptive_by_resolution": adaptive_by_resolution,
            "adaptive_profile": adaptive_profile,
        })
        self.step_limit = 0

        # Apply the selected resolution.
        if apply_resolution:
            wh = parse_resolution_label(resolution_choice)
            if wh:
                p.width, p.height = wh

        # Disabled: remove callbacks and wrappers.
        if not enable or self.disable:
            try:
                script_callbacks.remove_current_script_callbacks()
            except Exception:
                pass
            self._cb_registered = False
            try:
                KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
            except Exception:
                pass
            return

        # Adapt the values to the actual resolution.
        use_s1, use_s2 = s1, s2
        use_d1, use_d2 = d1, d2
        use_down, use_up = downscale, upscale
        if adaptive_by_resolution:
            try:
                use_s1, use_s2, use_d1, use_d2, use_down, use_up = _compute_adaptive_params(
                    int(p.width), int(p.height), adaptive_profile,
                    s1, s2, d1, d2, downscale, upscale, keep_unitary_product,
                )
            except Exception:
                # Fall back to the unadapted UI values.
                pass

        if use_s1 > use_s2:
            use_s2 = use_s1

        model = p.sd_model.model.diffusion_model
        max_inp = len(getattr(model, "input_blocks", [])) - 1
        if max_inp < 0:
            return

        # UI depths are 1-based; clamp them to valid input-block indices.
        d1_idx = max(0, min(int(use_d1) - 1, max_inp))
        d2_idx = max(0, min(int(use_d2) - 1, max_inp))
        scaler_mode = _safe_mode(scaler)

        # Merge the two pairs by depth: the later stop fraction wins per depth.
        combined: Dict[int, float] = {}
        for s_stop, d_idx in ((float(use_s1), d1_idx), (float(use_s2), d2_idx)):
            combined[d_idx] = max(combined.get(d_idx, 0.0), s_stop)

        # Diagnostics (printed once per process() call).
        _diag_printed = {"done": False}

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            """Wrap, rescale or unwrap the U-Net blocks for the current step."""
            if params.sampling_step < self.step_limit:
                return

            total = max(1, int(params.total_sampling_steps))

            if not _diag_printed["done"]:
                try:
                    nin = len(getattr(model, "input_blocks", []))
                    nout = len(getattr(model, "output_blocks", []))
                    print(f"[KohyaHiresFix] input_blocks={nin}, output_blocks={nout}, early_out={early_out}, "
                          f"smooth={smooth_scaling}, keep1={keep_unitary_product}, scaler={scaler_mode}")
                    for d_idx, s_stop in combined.items():
                        oi = KohyaHiresFix._map_output_index(model, d_idx, early_out)
                        print(f"[KohyaHiresFix] depth {d_idx} -> out {oi}, stop@{s_stop:.3f}, "
                              f"down={use_down:.3f}, up={use_up:.3f}")
                finally:
                    _diag_printed["done"] = True

            for d_idx, s_stop in combined.items():
                out_idx = KohyaHiresFix._map_output_index(model, d_idx, early_out)
                if out_idx is None:
                    # No paired output block: skip this stage.
                    continue

                try:
                    if params.sampling_step < total * s_stop:
                        if not isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = Scaler(
                                use_down, model.input_blocks[d_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )
                            model.output_blocks[out_idx] = Scaler(
                                use_up, model.output_blocks[out_idx], scaler_mode,
                                align_mode, recompute_mode,
                            )

                        if smooth_scaling:
                            # ratio in [0..1], optionally reshaped by smoothstep.
                            ratio = params.sampling_step / (total * s_stop)
                            ratio = float(max(0.0, min(1.0, ratio)))
                            if (smoothing_curve or "").lower().startswith("smooth"):
                                # smoothstep: t^2 * (3 - 2t)
                                ratio = ratio * ratio * (3.0 - 2.0 * ratio)

                            # Interpolate the downscale from use_down towards 1.0.
                            cur_down = min((1.0 - use_down) * ratio + use_down, 1.0)
                            model.input_blocks[d_idx].scale = cur_down

                            if keep_unitary_product:
                                cur_up = 1.0 / max(1e-6, cur_down)
                            else:
                                cur_up = use_up * (use_down / max(1e-6, cur_down))
                            model.output_blocks[out_idx].scale = cur_up
                    else:
                        # Past the stop fraction: restore the original blocks.
                        if isinstance(model.input_blocks[d_idx], Scaler):
                            model.input_blocks[d_idx] = model.input_blocks[d_idx].block
                        if isinstance(model.output_blocks[out_idx], Scaler):
                            model.output_blocks[out_idx] = model.output_blocks[out_idx].block

                except Exception as e:
                    # Fatal error: unwrap, remove the callback and disable the
                    # extension (self.disable stays True afterwards).
                    try:
                        KohyaHiresFix._unwrap_all(model)
                    finally:
                        try:
                            script_callbacks.remove_current_script_callbacks()
                        except Exception:
                            pass
                        self._cb_registered = False
                        self.disable = True
                        print(f"[KohyaHiresFix] Отключено после ошибки: {type(e).__name__}: {e}")
                    return  # leave the callback immediately

            self.step_limit = int(params.sampling_step) if only_one_pass else 0

        # Replace the callback. Removal is unconditional: postprocess() resets
        # _cb_registered, so relying on that flag would leave the previous
        # run's callback registered and stack a new one on every generation.
        try:
            script_callbacks.remove_current_script_callbacks()
        except Exception:
            pass
        self._cb_registered = False

        script_callbacks.on_cfg_denoiser(denoiser_callback)
        self._cb_registered = True

        # Infotext: the effective values plus the selected modes.
        parameters = {
            "DSHF_res": resolution_choice,
            "DSHF_apply_res": apply_resolution,
            "DSHF_s1": use_s1, "DSHF_d1": use_d1,
            "DSHF_s2": use_s2, "DSHF_d2": use_d2,
            "DSHF_scaler": scaler_mode,
            "DSHF_down": use_down, "DSHF_up": use_up,
            "DSHF_smooth": smooth_scaling,
            "DSHF_smooth_curve": smoothing_curve,
            "DSHF_early": early_out,
            "DSHF_one": only_one_pass,
            "DSHF_keep1": keep_unitary_product,
            "DSHF_align": align_corners_mode_ui,
            "DSHF_recompute": recompute_scale_factor_mode_ui,
            "DSHF_adapt": adaptive_by_resolution,
            "DSHF_adapt_profile": adaptive_profile,
        }
        for k, v in parameters.items():
            p.extra_generation_params[k] = v

    def postprocess(self, p, processed, *args):
        """Unwrap the U-Net, drop the denoiser callback and save the config."""
        try:
            KohyaHiresFix._unwrap_all(p.sd_model.model.diffusion_model)
        finally:
            # Remove the callback together with resetting the flag, so the
            # flag never claims "nothing registered" while a callback remains.
            try:
                script_callbacks.remove_current_script_callbacks()
            except Exception:
                pass
            self._cb_registered = False
            try:
                _atomic_save_yaml(
                    CONFIG_PATH,
                    OmegaConf.to_container(self.config, resolve=True) or {},
                )
            except Exception:
                pass

    def process_batch(self, p, *args, **kwargs):
        """Reset the step limit at the start of each batch."""
        self.step_limit = 0