# Deep Shrink Hires.fix (RU++ v2.1 UI Toggles, fixed)
# Compatibility: Python 3.10+, PyTorch >= 2.0, AUTOMATIC1111 WebUI >= 1.9
"""Deep Shrink Hires.fix extension script.

Registers an alternative U-Net that temporarily shrinks feature maps at
configurable depths/timesteps, plus experimental per-convolution scaling,
dilation and multiplier controls.  All settings live as class-level state on
``DSHF`` and are refreshed from the UI in ``DSHF.process``.
"""

from dataclasses import dataclass
from typing import List, Optional, Dict, Any
import json
import math
import os

import gradio as gr
import torch
import torch.nn.functional as F

import modules.devices as devices
import modules.scripts as scripts
import modules.script_callbacks as script_callbacks
import modules.sd_unet as sd_unet
import modules.shared as shared
from ldm.modules.attention import SpatialTransformer  # noqa: F401
from ldm.modules.diffusionmodules.openaimodel import Upsample, Downsample, ResBlock
from ldm.modules.diffusionmodules.util import timestep_embedding

# Interpolation modes offered by the UI; profiles carrying anything else are
# rejected so that F.interpolate is never called with an unsupported mode.
_INTERP_METHODS = ("nearest", "bilinear", "bicubic", "area")
# Upper clamp for min/max depth; matches the UI and class default of 999.
_MAX_DEPTH_LIMIT = 999


# -------------------------- Utilities --------------------------

def _to_scalar(x) -> float:
    """Convert a one-element tensor or a plain number to ``float``."""
    if isinstance(x, torch.Tensor):
        return float(x.item())
    return float(x)


def _clamp(v: float, lo: float, hi: float) -> float:
    """Clamp ``v`` into the closed interval ``[lo, hi]``."""
    return max(lo, min(hi, v))


def _safe_size(h: torch.Tensor, scale_factor: float) -> tuple[int, int]:
    """Return the target (H, W) for scaling ``h`` by ``scale_factor``.

    The original size is returned unchanged when either scaled side would
    fall below ``DSHF.min_feature_size``.
    """
    h_in, w_in = h.shape[-2], h.shape[-1]
    h_out = max(2, int(round(h_in * scale_factor)))
    w_out = max(2, int(round(w_in * scale_factor)))
    if h_out < DSHF.min_feature_size or w_out < DSHF.min_feature_size:
        return h_in, w_in
    return h_out, w_out


def _interpolate(img: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Resize the last two dims of ``img`` to ``size`` using the configured mode.

    The interpolation runs in float32 and the result is cast back to the
    input dtype.
    """
    size = (int(size[0]), int(size[1]))
    if size == tuple(img.shape[-2:]):
        return img
    dtype = img.dtype
    mode = DSHF.interp_method
    smooth = mode in ("bilinear", "bicubic")
    align = False if smooth else None
    try:
        out = F.interpolate(
            img.float(), size=size, mode=mode, align_corners=align,
            antialias=bool(DSHF.interp_antialias) if smooth else False,
        )
    except TypeError:
        # Fallback for F.interpolate signatures without the `antialias` kwarg.
        out = F.interpolate(img.float(), size=size, mode=mode, align_corners=align)
    return out.to(dtype)


def _resize(h: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Scale ``h`` spatially by ``scale_factor``; 1.0 or non-positive is a no-op."""
    if scale_factor == 1.0 or scale_factor <= 0:
        return h
    return _interpolate(h, _safe_size(h, scale_factor))


def _match_size(h: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Resize ``h`` so its spatial dims equal those of ``ref`` (no-op if equal)."""
    return _interpolate(h, tuple(ref.shape[-2:]))


def _parse_number_list(text: str, as_int: bool = False) -> List[float]:
    """Parse a ``;``-separated list of numbers; ``a/b`` fractions are allowed.

    Raises:
        ValueError: if ``text`` is None, contains no values, contains a
            non-numeric token, or contains a fraction with a zero denominator.
    """
    if text is None:
        raise ValueError("Пустая строка параметров.")
    values: List[float] = []
    for raw in str(text).replace("\n", " ").split(";"):
        s = raw.strip()
        if not s:
            continue
        if "/" in s:
            a, b = s.split("/", 1)
            try:
                val = float(a.strip()) / float(b.strip())
            except ZeroDivisionError:
                # Surface as ValueError so callers handle one exception type.
                raise ValueError(f"Деление на ноль в значении '{s}'.") from None
        else:
            val = float(s)
        values.append(val)
    if not values:
        raise ValueError("Не найдено ни одного валидного значения.")
    return [int(round(v)) for v in values] if as_int else values


def _get_or_last(seq: List[float], index: int, default: float) -> float:
    """Return ``seq[index]``, the last element if out of range, or ``default`` if empty."""
    if not seq:
        return default
    return seq[index] if index < len(seq) else seq[-1]


def _preset_path() -> str:
    """Path of the presets JSON file stored next to this script."""
    return os.path.join(os.path.dirname(__file__), "dshf_presets.json")


def _load_all_presets() -> Dict[str, Any]:
    """Load the preset store; return an empty store if missing or malformed."""
    empty: Dict[str, Any] = {"version": 2, "presets": {}}
    p = _preset_path()
    if not os.path.exists(p):
        return empty
    try:
        with open(p, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        # Best effort: an unreadable preset file must not break the UI.
        return empty
    # Callers use data["presets"] as a dict, so reject any other shape.
    if isinstance(data, dict) and isinstance(data.get("presets"), dict):
        return data
    return empty


def _save_all_presets(data: Dict[str, Any]) -> None:
    """Write the preset store to disk; failures are reported, not raised."""
    p = _preset_path()
    try:
        with open(p, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"[DSHF] Не удалось сохранить пресеты: {e}")


def _build_profile_dict() -> Dict[str, Any]:
    """Snapshot the current ``DSHF`` class state as a JSON-serialisable profile."""
    return {
        "version": 2,
        "actions": [
            {"enable": a.enable, "timestep": a.timestep, "depth": a.depth, "scale": a.scale}
            for a in DSHF.dshf_actions
        ],
        "experimental_enable": DSHF.enableExperimental,
        "experimental": [
            {
                "enable": e.enable, "timestep": e.timestep, "scales": e.scales,
                "in_multipliers": e.in_multipliers, "out_multipliers": e.out_multipliers,
                "dilations": e.dilations, "cfg_scale_scale": e.cfg_scale_scale,
            }
            for e in DSHF.dshf_experimental_actions
        ],
        "curve": {
            "enable": DSHF.enable_curve, "type": DSHF.curve_type,
            "t_start": DSHF.curve_t_start, "t_end": DSHF.curve_t_end,
            "scale_start": DSHF.curve_scale_start, "scale_end": DSHF.curve_scale_end,
            "alpha": DSHF.curve_alpha, "min_feature": DSHF.min_feature_size,
            "auto_end_enable": DSHF.auto_end_enable, "auto_end_strength": DSHF.auto_end_strength,
        },
        "runtime": {
            "timestep_policy": DSHF.timestep_policy, "interp_method": DSHF.interp_method,
            "interp_antialias": DSHF.interp_antialias, "channels_last": DSHF.channels_last,
            "enable_soft_clamp": DSHF.enable_soft_clamp, "soft_clamp_beta": DSHF.soft_clamp_beta,
            "min_depth": DSHF.min_depth, "max_depth": DSHF.max_depth,
        },
    }


def _apply_profile_dict(data: Dict[str, Any]) -> None:
    """Apply a profile produced by ``_build_profile_dict`` to ``DSHF`` state.

    The profile is fully parsed before anything is written, so a malformed
    profile leaves the current settings untouched.  Errors are printed, not
    raised.
    """
    try:
        if not isinstance(data, dict):
            raise TypeError("профиль должен быть JSON-объектом")
        actions = [
            DSHFAction(
                bool(a.get("enable", True)), float(a.get("timestep", 0)),
                int(a.get("depth", 0)), float(a.get("scale", 1.0)),
            )
            for a in data.get("actions", [])
        ]
        experimental = [
            DSHFExperimentalAction(
                bool(e.get("enable", False)),
                float(e.get("timestep", 0)),
                list(map(float, e.get("scales", []))),
                list(map(float, e.get("in_multipliers", []))),
                list(map(float, e.get("out_multipliers", []))),
                list(map(int, e.get("dilations", []))),
                list(map(float, e.get("cfg_scale_scale", []))),
            )
            for e in data.get("experimental", [])
        ]
        c = data.get("curve", {})
        r = data.get("runtime", {})
        interp = str(r.get("interp_method", DSHF.interp_method))
        if interp not in _INTERP_METHODS:
            # An unknown mode would only fail later inside F.interpolate.
            interp = DSHF.interp_method
        settings = {
            "enableExperimental": bool(data.get("experimental_enable", False)),
            "enable_curve": bool(c.get("enable", False)),
            "curve_type": str(c.get("type", "linear")),
            "curve_t_start": float(c.get("t_start", 800)),
            "curve_t_end": float(c.get("t_end", 200)),
            "curve_scale_start": float(c.get("scale_start", 1.0)),
            "curve_scale_end": float(c.get("scale_end", 1.0)),
            "curve_alpha": float(_clamp(float(c.get("alpha", 0.5)), 0.0, 1.0)),
            "min_feature_size": int(_clamp(float(c.get("min_feature", 8)), 2, 256)),
            "auto_end_enable": bool(c.get("auto_end_enable", False)),
            "auto_end_strength": float(_clamp(float(c.get("auto_end_strength", 0.35)), 0.0, 1.0)),
            "timestep_policy": str(r.get("timestep_policy", DSHF.timestep_policy)),
            "interp_method": interp,
            "interp_antialias": bool(r.get("interp_antialias", DSHF.interp_antialias)),
            "channels_last": bool(r.get("channels_last", DSHF.channels_last)),
            "enable_soft_clamp": bool(r.get("enable_soft_clamp", DSHF.enable_soft_clamp)),
            "soft_clamp_beta": float(_clamp(float(r.get("soft_clamp_beta", DSHF.soft_clamp_beta)), 0.0, 5.0)),
            "min_depth": int(_clamp(int(r.get("min_depth", DSHF.min_depth)), 0, _MAX_DEPTH_LIMIT)),
            "max_depth": int(_clamp(int(r.get("max_depth", DSHF.max_depth)), 0, _MAX_DEPTH_LIMIT)),
        }
    except Exception as e:
        print(f"[DSHF] Ошибка применения профиля: {e}")
        return

    # Commit only after the whole profile validated.
    DSHF.dshf_actions[:] = actions
    DSHF.dshf_experimental_actions[:] = experimental
    for attr, value in settings.items():
        setattr(DSHF, attr, value)


# -------------------------- Data structures --------------------------

@dataclass
class DSHFAction:
    """One depth rule: shrink by ``scale`` at block ``depth`` while timestep >= ``timestep``."""
    enable: bool
    timestep: float
    depth: int
    scale: float


@dataclass
class DSHFExperimentalAction:
    """One experimental set: per-conv scales/dilations and per-block multipliers."""
    enable: bool
    timestep: float
    scales: List[float]
    in_multipliers: List[float]
    out_multipliers: List[float]
    dilations: List[int]
    cfg_scale_scale: List[float]


# -------------------------- Main script --------------------------

class DSHF(scripts.Script):
    """WebUI script holding the Deep Shrink settings as shared class-level state."""

    dshf_actions: List[DSHFAction] = []
    enableExperimental: bool = False
    dshf_experimental_actions: List[DSHFExperimentalAction] = []
    # Counters updated by the patched U-Net during a forward pass.
    currentBlock: int = 0
    currentConv: int = 0
    currentTimestep: float = 1000.0

    enable_curve: bool = False
    curve_type: str = "linear"
    curve_t_start: float = 800.0
    curve_t_end: float = 200.0
    curve_scale_start: float = 1.0
    curve_scale_end: float = 1.0
    curve_alpha: float = 0.5
    min_feature_size: int = 8
    auto_end_enable: bool = False
    auto_end_strength: float = 0.35

    timestep_policy: str = "min"
    interp_method: str = "bicubic"
    interp_antialias: bool = True
    channels_last: bool = False
    enable_soft_clamp: bool = False
    soft_clamp_beta: float = 1.5
    min_depth: int = 0
    max_depth: int = 999

    def title(self):
        return "Deep Shrink Hires.fix (RU++ v2.1)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    @staticmethod
    def _reset_state() -> None:
        """Turn every effect off so the patched U-Net behaves like the original."""
        DSHF.dshf_actions.clear()
        DSHF.enableExperimental = False
        DSHF.dshf_experimental_actions = []
        DSHF.enable_curve = False
        DSHF.auto_end_enable = False
        DSHF.enable_soft_clamp = False

    @staticmethod
    def _active_experimental() -> Optional[DSHFExperimentalAction]:
        """Return the first enabled experimental set whose timestep <= current, if any."""
        if not DSHF.enableExperimental:
            return None
        ts = DSHF.currentTimestep
        for a in DSHF.dshf_experimental_actions:
            if a.enable and a.timestep <= ts:
                return a
        return None

    @staticmethod
    def _curve_weight(ts: float) -> Optional[float]:
        """Return the global curve scale at timestep ``ts``, or None if the curve is off."""
        if not DSHF.enable_curve:
            return None
        t0, t1 = DSHF.curve_t_start, DSHF.curve_t_end
        if t0 == t1:
            return DSHF.curve_scale_end
        x = _clamp((ts - t1) / (t0 - t1), 0.0, 1.0)
        if DSHF.curve_type == "linear":
            w = x
        elif DSHF.curve_type == "cosine":
            w = 0.5 - 0.5 * math.cos(x * math.pi)
        else:
            # Any other type is treated as a sigmoid centred on x = 0.5.
            w = 1.0 / (1.0 + math.exp(-10.0 * (x - 0.5)))
        s0, s1 = float(DSHF.curve_scale_start), float(DSHF.curve_scale_end)
        return _clamp(s0 + (s1 - s0) * w, 0.25, 4.0)

    @staticmethod
    def _block_scale(depth: int) -> Optional[float]:
        """Return the shrink factor for block ``depth``, or None if nothing applies.

        Combines the first matching depth rule with the global curve; when both
        apply their product is clamped to [0.25, 4.0].
        """
        if depth < DSHF.min_depth or depth > DSHF.max_depth:
            return None
        ts = DSHF.currentTimestep
        rule_scale = None
        for a in DSHF.dshf_actions:
            # Non-positive scales are skipped: callers divide by the result.
            if a.enable and a.depth == depth and a.timestep <= ts and a.scale > 0:
                rule_scale = a.scale
                break
        curve_scale = DSHF._curve_weight(ts)
        if rule_scale is None:
            return curve_scale
        if curve_scale is None:
            return rule_scale
        return _clamp(rule_scale * curve_scale, 0.25, 4.0)

    @staticmethod
    def _auto_scale_end(p) -> Optional[float]:
        """Derive ``curve_scale_end`` from the hires target/base resolution ratio.

        Returns None when auto mode or the curve is off, when sizes are
        unavailable, or when the target is not larger than the base.
        """
        if not DSHF.auto_end_enable or not DSHF.enable_curve:
            return None
        try:
            bw, bh = int(getattr(p, "width", 0)), int(getattr(p, "height", 0))
            if bw <= 0 or bh <= 0:
                return None
            tw, th = bw, bh
            if getattr(p, "enable_hr", False):
                hrx, hry = int(getattr(p, "hr_resize_x", 0)), int(getattr(p, "hr_resize_y", 0))
                hrs = float(getattr(p, "hr_scale", 0.0) or 0.0)
                if hrx > 0 and hry > 0:
                    tw, th = hrx, hry
                elif hrs > 0.0:
                    tw, th = int(round(bw * hrs)), int(round(bh * hrs))
            r = (max(1, tw * th) / max(1, bw * bh)) ** 0.5
            if r <= 1.0:
                return None
            strength = float(_clamp(DSHF.auto_end_strength, 0.0, 1.0))
            return _clamp(1.0 + strength * (r - 1.0), 1.0, 1.7)
        except Exception:
            # Best effort: `p` attributes may be missing or non-numeric.
            return None

    @staticmethod
    def _pick_timestep_scalar(timesteps: torch.Tensor) -> float:
        """Reduce the batch of timesteps to one scalar per ``timestep_policy``."""
        pol = DSHF.timestep_policy
        # Flatten so a 0-dim tensor can be indexed as well.
        vals = timesteps.detach().float().reshape(-1)
        if pol == "first":
            return _to_scalar(vals[0])
        if pol == "max":
            return float(vals.max().item())
        if pol == "mean":
            return float(vals.mean().item())
        return float(vals.min().item())

    @staticmethod
    def _soft_clamp(h: torch.Tensor) -> torch.Tensor:
        """Clamp ``h`` per channel into ``mean ± beta·std`` over dims (2, 3)."""
        if not DSHF.enable_soft_clamp:
            return h
        beta = float(DSHF.soft_clamp_beta)
        if beta <= 0:
            return h
        # std over fewer than two elements is NaN and would poison the output.
        if h.dim() != 4 or h.shape[-2] * h.shape[-1] < 2:
            return h
        mean = h.mean(dim=(2, 3), keepdim=True)
        spread = (h.std(dim=(2, 3), keepdim=True) + 1e-6) * beta
        return torch.minimum(torch.maximum(h, mean - spread), mean + spread)

    # --- UI ---
    def ui(self, is_img2img):
        """Build the Gradio UI; the returned list order is the arg order of ``process``."""
        presets = _load_all_presets().get("presets", {})
        preset_names = sorted(presets)

        def toggle(v):
            return gr.update(visible=bool(v))

        with gr.Tabs():
            # ===== Settings =====
            with gr.TabItem("Настройки"):
                Enable_Ext = gr.Checkbox(value=True, label="Включить расширение")

                # Main thresholds
                with gr.Accordion(label="Основные пороги (1–2)", open=False):
                    En_Main = gr.Checkbox(value=True, label="Включить секцию")
                    with gr.Group(visible=True) as MainGrp:
                        with gr.Row():
                            Enable_1 = gr.Checkbox(value=True, label="Включить правило 1")
                            Timestep_1 = gr.Number(value=625, label="Timestep 1")
                            Depth_1 = gr.Number(value=3, label="Глубина блока 1", precision=0)
                            Scale_1 = gr.Number(value=2.0, label="Коэффициент масштаба 1")
                        with gr.Row():
                            Enable_2 = gr.Checkbox(value=True, label="Включить правило 2")
                            Timestep_2 = gr.Number(value=0, label="Timestep 2")
                            Depth_2 = gr.Number(value=3, label="Глубина блока 2", precision=0)
                            Scale_2 = gr.Number(value=2.0, label="Коэффициент масштаба 2")
                    En_Main.change(toggle, En_Main, MainGrp)

                # Advanced thresholds
                with gr.Accordion(label="Расширенные пороги (3–8)", open=False):
                    En_Adv = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as AdvGrp:
                        rows = []
                        defaults = [(False, 900, 3, 2.0), (False, 650, 3, 2.0), (False, 900, 3, 2.0),
                                    (False, 650, 3, 2.0), (False, 900, 3, 2.0), (False, 650, 3, 2.0)]
                        for idx, (en, ts, dp, sc) in enumerate(defaults, start=3):
                            with gr.Row():
                                rows.append((
                                    gr.Checkbox(value=en, label=f"Включить правило {idx}"),
                                    gr.Number(value=ts, label=f"Timestep {idx}"),
                                    gr.Number(value=dp, label=f"Глубина блока {idx}", precision=0),
                                    gr.Number(value=sc, label=f"Коэффициент масштаба {idx}"),
                                ))
                    En_Adv.change(toggle, En_Adv, AdvGrp)

                # Experimental
                with gr.Accordion(label="Экспериментальные (масштабы/дилатации/множители)", open=False):
                    Enable_Experimental = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as ExpGrp:
                        def block(prefix, ts_default):
                            """Create one experimental set; returns components in process() order."""
                            with gr.Row():
                                en = gr.Checkbox(value=True, label=f"{prefix}: включить набор")
                                ts = gr.Number(value=ts_default, label=f"{prefix}: timestep")
                            with gr.Row():
                                sc = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: масштабы (по свёрткам)", lines=2)
                            with gr.Row():
                                cfg = gr.Textbox(value="1;1;1; 1;1;1; 1;1;1; 1;1;1; 1; 1;1;1; 1;1;1; 1;1;1; 1;1;1", label=f"{prefix}: множители CFG-scale")
                                dil = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: дилатации (по свёрткам)", lines=2)
                            with gr.Row():
                                pre = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: входные умножители (по блокам)")
                                post = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: выходные умножители (по блокам)")
                            return en, ts, sc, pre, post, dil, cfg

                        exp_blocks = [
                            block("Набор 1", 625),
                            block("Набор 2", 0),
                            block("Набор 3", 750),
                            block("Набор 4", 750),
                        ]
                    Enable_Experimental.change(toggle, Enable_Experimental, ExpGrp)

                # Curve
                with gr.Accordion(label="Глобальная кривая масштаба", open=False):
                    Enable_Curve = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as CurveGrp:
                        Curve_Type = gr.Dropdown(choices=["linear", "cosine", "sigmoid"], value="linear", label="Тип кривой")
                        with gr.Row():
                            Curve_t_start = gr.Number(value=800, label="t_start")
                            Curve_t_end = gr.Number(value=200, label="t_end")
                        with gr.Row():
                            Curve_scale_start = gr.Number(value=1.0, label="scale_start")
                            Curve_scale_end = gr.Number(value=1.0, label="scale_end")
                        Curve_alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="alpha (компенсация)")
                        Min_feature = gr.Slider(2, 64, value=8, step=1, label="Минимальный размер фичей")
                        with gr.Row():
                            Auto_end_enable = gr.Checkbox(value=False, label="Автоподбор scale_end")
                            Auto_end_strength = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Сила автоподбора")
                    Enable_Curve.change(toggle, Enable_Curve, CurveGrp)

                # Import
                with gr.Accordion(label="Импорт профиля (JSON)", open=False):
                    En_Import = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as ImportGrp:
                        Use_Import = gr.Checkbox(value=False, label="Применить JSON ниже")
                        Json_Profile = gr.Textbox(value="", lines=6, label="JSON: actions/experimental/curve/runtime")
                    En_Import.change(toggle, En_Import, ImportGrp)

                # Presets
                with gr.Accordion(label="Пресеты", open=False):
                    En_Presets = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as PresetGrp:
                        Preset_Apply = gr.Checkbox(value=False, label="Выполнить действие при генерации")
                        Preset_Action = gr.Dropdown(choices=["Сохранить", "Загрузить"], value="Загрузить", label="Действие")
                        with gr.Row():
                            Preset_Name = gr.Textbox(value="", label="Имя пресета")
                            Preset_Existing = gr.Dropdown(choices=preset_names or [""],
                                                          value=(preset_names[0] if preset_names else ""),
                                                          label="Выбрать существующий")
                        gr.Markdown("Подсказка: при «Загрузить» используется поле «Имя пресета», если оно заполнено.")
                    En_Presets.change(toggle, En_Presets, PresetGrp)

            # ===== Runtime =====
            with gr.TabItem("Выполнение"):
                En_Runtime = gr.Checkbox(value=True, label="Включить секцию")
                with gr.Group(visible=True) as RuntimeGrp:
                    with gr.Row():
                        Timestep_Policy = gr.Dropdown(choices=["first", "min", "max", "mean"], value="min", label="Политика timestep")
                        Interp_Method = gr.Dropdown(choices=["nearest", "bilinear", "bicubic", "area"], value="bicubic", label="Интерполяция")
                    with gr.Row():
                        Interp_Antialias = gr.Checkbox(value=True, label="Антиалиасинг (bilinear/bicubic)")
                        Channels_Last = gr.Checkbox(value=False, label="Оптимизация channels_last")
                    with gr.Row():
                        Enable_Soft_Clamp = gr.Checkbox(value=False, label="Мягкий клип амплитуды")
                        Soft_Clamp_Beta = gr.Slider(0.0, 5.0, value=1.5, step=0.1, label="beta (mean±beta·std)")
                    with gr.Row():
                        Min_Depth = gr.Number(value=0, label="Мин. глубина", precision=0)
                        Max_Depth = gr.Number(value=999, label="Макс. глубина", precision=0)
                En_Runtime.change(toggle, En_Runtime, RuntimeGrp)

            # ===== Help =====
            with gr.TabItem("Справка"):
                gr.Markdown("""
**Глобальные тумблеры** у каждой секции управляют и логикой, и показом.
Если секция выключена — её параметры игнорируются в `process()`.
""")

        # NOTE: process() consumes its args positionally in exactly this order.
        flat = [Enable_Ext, En_Main,
                Enable_1, Timestep_1, Depth_1, Scale_1,
                Enable_2, Timestep_2, Depth_2, Scale_2,
                En_Adv]
        for rule_row in rows:
            flat += list(rule_row)
        flat.append(Enable_Experimental)
        for exp_block in exp_blocks:
            flat += list(exp_block)
        flat += [
            Enable_Curve, Curve_Type, Curve_t_start, Curve_t_end, Curve_scale_start, Curve_scale_end,
            Curve_alpha, Min_feature, Auto_end_enable, Auto_end_strength,
            En_Import, Use_Import, Json_Profile,
            En_Presets, Preset_Apply, Preset_Action, Preset_Name, Preset_Existing,
            En_Runtime, Timestep_Policy, Interp_Method, Interp_Antialias, Channels_Last,
            Enable_Soft_Clamp, Soft_Clamp_Beta, Min_Depth, Max_Depth,
        ]
        return flat

    def process(self, p, *args):
        """Copy the UI values in ``args`` (order defined by ``ui``) into ``DSHF`` state.

        Does nothing unless the Deep Shrink U-Net is the active U-Net.  When
        the extension checkbox is off, all effects are reset so that settings
        from a previous run do not keep acting on the patched model.

        Application order: rules, experimental, curve, runtime, then the JSON
        import and preset action (a loaded profile therefore overrides the UI
        values, and a saved preset captures the values of this run), then the
        automatic ``scale_end``.

        Raises:
            ValueError: if a list field of an *active* experimental set cannot
                be parsed.
        """
        if not isinstance(sd_unet.current_unet, DSHF.DeepShrinkHiresFixUNet):
            return
        it = iter(args)

        def nxt():
            return next(it)

        if not bool(nxt()):
            DSHF._reset_state()
            return

        def read_rule():
            return (bool(nxt()), _to_scalar(nxt()), int(_to_scalar(nxt())), float(_to_scalar(nxt())))

        def read_list(as_int: bool, required: bool):
            # Always consume the arg; only an active set may fail the run.
            text = str(nxt())
            try:
                return _parse_number_list(text, as_int=as_int)
            except ValueError:
                if required:
                    raise
                return []

        # Main and advanced thresholds
        en_main = bool(nxt())
        base_rules = [read_rule() for _ in range(2)]
        en_adv = bool(nxt())
        adv_rules = [read_rule() for _ in range(6)]
        off = (False, 0, 0, 1.0)
        rules = (base_rules if en_main else [off] * 2) + (adv_rules if en_adv else [off] * 6)
        DSHF.dshf_actions[:] = [
            DSHFAction(bool(en), float(ts), int(dp), float(sc)) for (en, ts, dp, sc) in rules
        ]

        # Experimental
        DSHF.enableExperimental = bool(nxt())
        exp_sets = []
        for _ in range(4):
            en = bool(nxt())
            ts = _to_scalar(nxt())
            active = DSHF.enableExperimental and en
            sc = read_list(False, active)
            pre = read_list(False, active)
            post = read_list(False, active)
            dil = read_list(True, active)
            cfg = read_list(False, active)
            exp_sets.append(DSHFExperimentalAction(en, ts, sc, pre, post, dil, cfg))
        DSHF.dshf_experimental_actions = exp_sets if DSHF.enableExperimental else []

        # Curve
        DSHF.enable_curve = bool(nxt())
        curve_type = str(nxt())
        t0 = _to_scalar(nxt())
        t1 = _to_scalar(nxt())
        s0 = float(_to_scalar(nxt()))
        s1 = float(_to_scalar(nxt()))
        alpha = float(_clamp(_to_scalar(nxt()), 0.0, 1.0))
        minfeat = int(_clamp(_to_scalar(nxt()), 2, 256))
        auto_en = bool(nxt())
        auto_k = float(_clamp(_to_scalar(nxt()), 0.0, 1.0))
        if DSHF.enable_curve:
            DSHF.curve_type, DSHF.curve_t_start, DSHF.curve_t_end = curve_type, t0, t1
            DSHF.curve_scale_start, DSHF.curve_scale_end = s0, s1
            DSHF.curve_alpha, DSHF.min_feature_size = alpha, minfeat
            DSHF.auto_end_enable, DSHF.auto_end_strength = auto_en, auto_k
        else:
            DSHF.auto_end_enable = False

        # Read the remaining sections before acting on any of them.
        en_import = bool(nxt())
        use_import = bool(nxt())
        json_text = str(nxt() or "").strip()

        en_preset = bool(nxt())
        preset_apply = bool(nxt())
        action = str(nxt() or "")
        name = str(nxt() or "").strip()
        existing = str(nxt() or "").strip()

        en_run = bool(nxt())
        pol = str(nxt())
        im = str(nxt())
        aa = bool(nxt())
        chlast = bool(nxt())
        sclamp = bool(nxt())
        beta = float(_clamp(_to_scalar(nxt()), 0.0, 5.0))
        mind = int(_clamp(_to_scalar(nxt()), 0, _MAX_DEPTH_LIMIT))
        maxd = int(_clamp(_to_scalar(nxt()), 0, _MAX_DEPTH_LIMIT))

        # Runtime first, so presets saved below contain this run's values and
        # loaded profiles override the UI consistently with the curve section.
        if en_run:
            DSHF.timestep_policy, DSHF.interp_method, DSHF.interp_antialias = pol, im, aa
            DSHF.channels_last, DSHF.enable_soft_clamp, DSHF.soft_clamp_beta = chlast, sclamp, beta
            DSHF.min_depth, DSHF.max_depth = mind, maxd

        # Import
        if en_import and use_import and json_text:
            try:
                _apply_profile_dict(json.loads(json_text))
            except Exception as e:
                print(f"[DSHF] Ошибка JSON-профиля: {e}")

        # Presets
        if en_preset and preset_apply:
            store = _load_all_presets()
            bag = store.get("presets", {})
            key = name or existing
            if action == "Сохранить":
                if key:
                    bag[key] = _build_profile_dict()
                    store["presets"] = bag
                    _save_all_presets(store)
                    print(f"[DSHF] Пресет сохранён: '{key}'")
            else:
                prof = bag.get(key)
                if prof:
                    _apply_profile_dict(prof)
                    print(f"[DSHF] Пресет загружен: '{key}'")

        auto = self._auto_scale_end(p)
        if auto is not None:
            DSHF.curve_scale_end = float(auto)

    # ---------------- Conv2d wrappers ----------------
    class DSHF_Scale(torch.nn.Module):
        """Pre-conv hook: shrinks the input and sets the conv's dilation/padding."""

        def __init__(self, conv2d_ref: List[torch.nn.Conv2d]):
            super().__init__()
            # Held in a list so the conv is not registered as a submodule here.
            self.conv2d_ref = conv2d_ref
            conv = conv2d_ref[0]
            # Remembered so the conv can be restored when no set is active.
            self._orig_dilation = getattr(conv, "dilation", None)
            self._orig_padding = getattr(conv, "padding", None)

        def _restore(self, conv) -> None:
            if self._orig_dilation is not None:
                conv.dilation = self._orig_dilation
            if self._orig_padding is not None:
                conv.padding = self._orig_padding

        def forward(self, h: torch.Tensor):
            conv = self.conv2d_ref[0]
            exp = DSHF._active_experimental()
            if exp is None:
                # Undo dilation left behind by a set that is no longer active.
                self._restore(conv)
                return h
            idx = DSHF.currentConv
            s = _get_or_last(exp.scales, idx, 1.0)
            if s > 0:
                h = _resize(h, 1.0 / s)
            k = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size)
            dil = int(_get_or_last(exp.dilations, idx, 1)) if max(k) > 1 else 1
            if dil <= 1:
                self._restore(conv)
            else:
                conv.dilation = (dil, dil)
                if isinstance(self._orig_padding, tuple):
                    # Scaling the original padding keeps the output size
                    # unchanged for odd "same"-padded kernels ((1, 1) -> (dil, dil)).
                    conv.padding = tuple(int(pad) * dil for pad in self._orig_padding)
                else:
                    conv.padding = (dil, dil)
            return h

    class DSHF_Unscale(torch.nn.Module):
        """Post-conv hook: re-expands the output and advances the conv counter."""

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                s = _get_or_last(exp.scales, DSHF.currentConv, 1.0)
                if s != 1.0 and s > 0:
                    h = _resize(h, s)
                    if DSHF.curve_alpha != 0.0:
                        h = h * (s ** DSHF.curve_alpha)
            h = DSHF._soft_clamp(h)
            DSHF.currentConv += 1
            return h

    class DSHF_InMul(torch.nn.Module):
        """Multiplies by the current block's experimental input multiplier."""

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                mul = _get_or_last(exp.in_multipliers, DSHF.currentBlock, 1.0)
                if mul != 1.0:
                    return h * mul
            return h

    class DSHF_OutMul(torch.nn.Module):
        """Multiplies by the current block's output multiplier, then soft-clamps."""

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                mul = _get_or_last(exp.out_multipliers, DSHF.currentBlock, 1.0)
                if mul != 1.0:
                    h = h * mul
            return DSHF._soft_clamp(h)

    # --------------- Patched U-Net ---------------
    class DeepShrinkHiresFixUNet(sd_unet.SdUnet):
        """U-Net wrapper that patches the given model's convolutions in place."""

        def __init__(self, _model):
            super().__init__()
            self.model = _model.to(devices.device)
            for ib in self.model.input_blocks:
                self._patch_container(ib)
            self._patch_container(self.model.middle_block)
            for ob in self.model.output_blocks:
                self._patch_container(ob)
            self._patch_container(self.model.out)

        @staticmethod
        def _wrap(conv, tail=None):
            """Return ``Scale -> conv -> Unscale [-> tail]`` as a Sequential."""
            parts = [DSHF.DSHF_Scale([conv]), conv, DSHF.DSHF_Unscale()]
            if tail is not None:
                parts.append(tail)
            return torch.nn.Sequential(*parts)

        @classmethod
        def _patch_container(cls, container) -> None:
            """Wrap every plain Conv2d reachable from the layers of ``container``.

            Only ``torch.nn.Conv2d`` instances are wrapped, so an already
            patched model (the wrapped model is shared and mutated in place)
            is never wrapped twice, and non-conv ops are left alone.
            """
            conv_type = torch.nn.Conv2d
            for j, layer in enumerate(container):
                if isinstance(layer, ResBlock):
                    for k, il in enumerate(layer.in_layers):
                        if isinstance(il, conv_type):
                            layer.in_layers[k] = cls._wrap(il, DSHF.DSHF_InMul())
                    for k, ol in enumerate(layer.out_layers):
                        if isinstance(ol, conv_type):
                            layer.out_layers[k] = cls._wrap(ol, DSHF.DSHF_OutMul())
                elif isinstance(layer, conv_type):
                    container[j] = cls._wrap(layer)
                elif isinstance(layer, Downsample):
                    if isinstance(getattr(layer, "op", None), conv_type):
                        layer.op = cls._wrap(layer.op)
                elif isinstance(layer, Upsample):
                    if isinstance(getattr(layer, "conv", None), conv_type):
                        layer.conv = cls._wrap(layer.conv)

        @staticmethod
        def _scaled_context(context):
            """Return ``context`` times the current block's CFG multiplier, if a set is active."""
            exp = DSHF._active_experimental()
            if exp is None:
                return context
            return context * float(_get_or_last(exp.cfg_scale_scale, DSHF.currentBlock, 1.0))

        def deactivate(self):
            # NOTE(review): assumes the WebUI calls deactivate() when another
            # U-Net is selected — confirm against modules.sd_unet.  The model
            # stays patched, so the wrappers must be made inert.
            DSHF._reset_state()

        def forward(self, x, timesteps, context, y=None, **kwargs):
            """Run the U-Net, shrinking/expanding features per the ``DSHF`` settings."""
            assert (y is not None) == (self.model.num_classes is not None), "must specify y iff class-conditional"
            if DSHF.channels_last:
                x = x.contiguous(memory_format=torch.channels_last)
            hs = []
            emb = self.model.time_embed(timestep_embedding(timesteps, self.model.model_channels, repeat_only=False))
            if self.model.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.model.label_emb(y)
            h = x.type(self.model.dtype)
            depth = 0
            DSHF.currentBlock = 0
            DSHF.currentConv = 0
            DSHF.currentTimestep = DSHF._pick_timestep_scalar(timesteps)

            for module in self.model.input_blocks:
                scale = DSHF._block_scale(depth)
                if scale is not None:
                    h = _resize(h, 1.0 / float(scale))
                h = module(h, emb, self._scaled_context(context))
                hs.append(h)
                depth += 1
                DSHF.currentBlock += 1

            scale = DSHF._block_scale(depth)
            size_before = tuple(h.shape[-2:])
            if scale is not None:
                h = _resize(h, 1.0 / float(scale))
            h = self.model.middle_block(h, emb, self._scaled_context(context))
            if scale is not None:
                # Restore the exact pre-shrink size instead of re-scaling, which
                # could be off by a pixel after rounding.
                h = _interpolate(h, size_before)
            DSHF.currentBlock += 1

            for module in self.model.output_blocks:
                context_tmp = self._scaled_context(context)
                depth -= 1
                skip = hs.pop()
                # Shrink-then-grow rounding may leave h a pixel off the skip.
                h = _match_size(h, skip)
                h = torch.cat([h, skip], dim=1)
                h = module(h, emb, context_tmp)
                scale = DSHF._block_scale(depth)
                if scale is not None:
                    h = _resize(h, float(scale))
                DSHF.currentBlock += 1

            # The prediction must have the spatial size of the input.
            h = _match_size(h, x)
            h = h.type(x.dtype)
            return self.model.id_predictor(h) if self.model.predict_codebook_ids else self.model.out(h)


# U-Net registration
DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption()
DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix"
DeepShrinkHiresFixUNetOption.create_unet = lambda: DSHF.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model)
script_callbacks.on_list_unets(lambda unets: unets.append(DeepShrinkHiresFixUNetOption))