# sdas / DeepShrinkHires.fix / scripts / DeepShrinkHires.fix.py
# Author: dikdimon
# Commit message: "Upload DeepShrinkHires.fix using SD-Hub"
# Commit: da29f0f (verified)
# Deep Shrink Hires.fix (RU++ v2.1 UI Toggles, fixed)
# Совместимость: Python 3.10+, PyTorch >= 2.0, AUTOMATIC1111 WebUI >= 1.9
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
import json, os
import gradio as gr
import torch
import torch.nn.functional as F
import modules.devices as devices
import modules.scripts as scripts
import modules.script_callbacks as script_callbacks
import modules.sd_unet as sd_unet
import modules.shared as shared
from ldm.modules.attention import SpatialTransformer # noqa: F401
from ldm.modules.diffusionmodules.openaimodel import Upsample, Downsample, ResBlock
from ldm.modules.diffusionmodules.util import timestep_embedding
# -------------------------- Утилиты --------------------------
def _to_scalar(x) -> float:
    """Convert a Python number or a one-element tensor to a plain ``float``."""
    raw = x.item() if isinstance(x, torch.Tensor) else x
    return float(raw)
def _clamp(v: float, lo: float, hi: float) -> float:
    """Limit *v* to ``[lo, hi]``; the lower bound wins if ``lo > hi``."""
    capped = min(hi, v)
    return max(lo, capped)
def _safe_size(h: torch.Tensor, scale_factor: float) -> tuple[int, int]:
    """Return the ``(height, width)`` *h* would have after scaling by *scale_factor*.

    Each side is rounded and floored at 2. If either resulting side is below
    ``DSHF.min_feature_size``, the current size is returned unchanged, i.e.
    the resize is refused.
    """
    in_h, in_w = h.shape[-2:]
    out_h, out_w = (max(2, int(round(n * scale_factor))) for n in (in_h, in_w))
    if min(out_h, out_w) < DSHF.min_feature_size:
        return in_h, in_w
    return out_h, out_w
def _interpolate(img: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Resample the last two dims of *img* to *size* with the configured method.

    The mode comes from ``DSHF.interp_method``. Antialiasing
    (``DSHF.interp_antialias``) and ``align_corners=False`` apply only to
    bilinear/bicubic. The resize runs in float32 and the result is cast back
    to the input dtype. *img* is returned as-is when it already has *size*.
    """
    if size == img.shape[-2:]:
        return img
    mode = DSHF.interp_method
    smooth = mode in ("bilinear", "bicubic")
    common = dict(size=size, mode=mode, align_corners=False if smooth else None)
    as_float = img.float()
    try:
        resized = F.interpolate(
            as_float,
            antialias=bool(DSHF.interp_antialias) if smooth else False,
            **common,
        )
    except TypeError:
        # Fallback for PyTorch builds whose interpolate() lacks `antialias`.
        resized = F.interpolate(as_float, **common)
    return resized.to(img.dtype)
def _resize(h: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Scale *h* spatially by *scale_factor*, honouring the minimum feature size.

    A factor of exactly 1.0 returns *h* untouched.
    """
    if scale_factor != 1.0:
        h = _interpolate(h, _safe_size(h, scale_factor))
    return h
def _parse_number_list(text: str, as_int: bool = False) -> List[float]:
    """Parse a ``;``-separated list of numbers from a textbox.

    Newlines count as whitespace, empty items are skipped, and an item of the
    form ``a/b`` is evaluated as a fraction. With *as_int* each value is
    rounded to ``int``.

    Raises:
        ValueError: if *text* is None, contains no values, or holds an
            unparsable item.
    """
    if text is None:
        raise ValueError("Пустая строка параметров.")
    parsed: List[float] = []
    for token in (part.strip() for part in str(text).replace("\n", " ").split(";")):
        if not token:
            continue
        numerator, slash, denominator = token.partition("/")
        if slash:
            parsed.append(float(numerator.strip()) / float(denominator.strip()))
        else:
            parsed.append(float(token))
    if not parsed:
        raise ValueError("Не найдено ни одного валидного значения.")
    if as_int:
        return [int(round(v)) for v in parsed]
    return parsed
def _get_or_last(seq: List[float], index: int, default: float) -> float:
    """Return ``seq[index]``, or the last element when *index* is past the end.

    An empty (or None) *seq* yields *default*.
    """
    if not seq:
        return default
    return seq[min(index, len(seq) - 1)]
def _preset_path() -> str:
    """Return the path of ``dshf_presets.json`` in this script's directory."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, "dshf_presets.json")
def _load_all_presets(path: Optional[str] = None) -> Dict[str, Any]:
    """Load the presets store ``{"version": ..., "presets": {name: profile}}``.

    Args:
        path: file to read; defaults to ``_preset_path()``.

    Returns:
        The parsed store. A fresh empty store is returned when the file is
        missing, unreadable, not valid JSON, or does not have a dict under
        ``"presets"``, so callers can always use ``result["presets"]`` as a
        dict. Read errors are printed rather than raised.
    """
    def empty() -> Dict[str, Any]:
        return {"version": 2, "presets": {}}

    p = _preset_path() if path is None else path
    if not os.path.exists(p):
        return empty()
    try:
        with open(p, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"[DSHF] Не удалось прочитать пресеты: {e}")
        return empty()
    # Reject any shape whose "presets" is not a mapping: ui() calls .keys() on it.
    if not isinstance(data, dict) or not isinstance(data.get("presets"), dict):
        return empty()
    return data
def _save_all_presets(data: Dict[str, Any], path: Optional[str] = None) -> None:
    """Write the presets store as pretty-printed UTF-8 JSON (best effort).

    The data is first written to ``<path>.tmp`` and then moved over the target
    with ``os.replace``. A failure part-way through serialisation (e.g. a
    non-JSON value) therefore leaves the previous presets file intact instead
    of truncating it. Errors are printed, not raised.

    Args:
        data: store to save, normally ``{"version": 2, "presets": {...}}``.
        path: destination file; defaults to ``_preset_path()``.
    """
    p = _preset_path() if path is None else path
    tmp = p + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp, p)
    except Exception as e:  # best effort: saving presets must never break generation
        print(f"[DSHF] Не удалось сохранить пресеты: {e}")
        try:
            os.remove(tmp)
        except OSError:
            pass
def _build_profile_dict() -> Dict[str, Any]:
    """Snapshot DSHF's current class-level settings as a JSON-ready profile.

    Top-level keys: ``version``, ``actions``, ``experimental_enable``,
    ``experimental``, ``curve``, ``runtime``. This is the layout that
    ``_apply_profile_dict`` reads back.
    """
    rules = [
        {"enable": rule.enable, "timestep": rule.timestep, "depth": rule.depth, "scale": rule.scale}
        for rule in DSHF.dshf_actions
    ]
    experiments = []
    for ex in DSHF.dshf_experimental_actions:
        experiments.append({
            "enable": ex.enable,
            "timestep": ex.timestep,
            "scales": ex.scales,
            "in_multipliers": ex.in_multipliers,
            "out_multipliers": ex.out_multipliers,
            "dilations": ex.dilations,
            "cfg_scale_scale": ex.cfg_scale_scale,
        })
    curve = {
        "enable": DSHF.enable_curve,
        "type": DSHF.curve_type,
        "t_start": DSHF.curve_t_start,
        "t_end": DSHF.curve_t_end,
        "scale_start": DSHF.curve_scale_start,
        "scale_end": DSHF.curve_scale_end,
        "alpha": DSHF.curve_alpha,
        "min_feature": DSHF.min_feature_size,
        "auto_end_enable": DSHF.auto_end_enable,
        "auto_end_strength": DSHF.auto_end_strength,
    }
    runtime = {
        "timestep_policy": DSHF.timestep_policy,
        "interp_method": DSHF.interp_method,
        "interp_antialias": DSHF.interp_antialias,
        "channels_last": DSHF.channels_last,
        "enable_soft_clamp": DSHF.enable_soft_clamp,
        "soft_clamp_beta": DSHF.soft_clamp_beta,
        "min_depth": DSHF.min_depth,
        "max_depth": DSHF.max_depth,
    }
    return {
        "version": 2,
        "actions": rules,
        "experimental_enable": DSHF.enableExperimental,
        "experimental": experiments,
        "curve": curve,
        "runtime": runtime,
    }
def _apply_profile_dict(data: Dict[str, Any]) -> None:
    """Load a profile (the ``_build_profile_dict`` layout) into DSHF's class state.

    Every value is converted and validated first. DSHF is modified only after
    the whole profile has been read, so a malformed profile leaves the current
    settings untouched instead of half-applied. Errors are printed, not raised.

    Defaults when keys are missing:
      * ``actions`` / ``experimental`` -> the respective list becomes empty;
      * ``curve`` keys -> fixed defaults (curve disabled, t 800->200, scales 1.0);
      * ``runtime`` keys -> the current DSHF values are kept.
    """
    try:
        actions = [
            DSHFAction(
                bool(a.get("enable", True)), float(a.get("timestep", 0)),
                int(a.get("depth", 0)), float(a.get("scale", 1.0)),
            )
            for a in data.get("actions", [])
        ]
        experimental_enable = bool(data.get("experimental_enable", False))
        experimental = [
            DSHFExperimentalAction(
                bool(e.get("enable", False)), float(e.get("timestep", 0)),
                list(map(float, e.get("scales", []))),
                list(map(float, e.get("in_multipliers", []))),
                list(map(float, e.get("out_multipliers", []))),
                list(map(int, e.get("dilations", []))),
                list(map(float, e.get("cfg_scale_scale", []))),
            )
            for e in data.get("experimental", [])
        ]
        c = data.get("curve", {})
        curve = {
            "enable_curve": bool(c.get("enable", False)),
            "curve_type": str(c.get("type", "linear")),
            "curve_t_start": float(c.get("t_start", 800)),
            "curve_t_end": float(c.get("t_end", 200)),
            "curve_scale_start": float(c.get("scale_start", 1.0)),
            "curve_scale_end": float(c.get("scale_end", 1.0)),
            "curve_alpha": float(_clamp(float(c.get("alpha", 0.5)), 0.0, 1.0)),
            "min_feature_size": int(_clamp(float(c.get("min_feature", 8)), 2, 256)),
            "auto_end_enable": bool(c.get("auto_end_enable", False)),
            "auto_end_strength": float(_clamp(float(c.get("auto_end_strength", 0.35)), 0.0, 1.0)),
        }
        r = data.get("runtime", {})
        runtime = {
            "timestep_policy": str(r.get("timestep_policy", DSHF.timestep_policy)),
            "interp_method": str(r.get("interp_method", DSHF.interp_method)),
            "interp_antialias": bool(r.get("interp_antialias", DSHF.interp_antialias)),
            "channels_last": bool(r.get("channels_last", DSHF.channels_last)),
            "enable_soft_clamp": bool(r.get("enable_soft_clamp", DSHF.enable_soft_clamp)),
            "soft_clamp_beta": float(_clamp(float(r.get("soft_clamp_beta", DSHF.soft_clamp_beta)), 0.0, 5.0)),
            "min_depth": int(_clamp(int(r.get("min_depth", DSHF.min_depth)), 0, 99)),
            "max_depth": int(_clamp(int(r.get("max_depth", DSHF.max_depth)), 0, 99)),
        }
    except Exception as e:  # user-supplied JSON: report and keep current settings
        print(f"[DSHF] Ошибка применения профиля: {e}")
        return
    # Commit phase: nothing below can fail on bad input.
    DSHF.dshf_actions[:] = actions
    DSHF.enableExperimental = experimental_enable
    DSHF.dshf_experimental_actions[:] = experimental
    for name, value in {**curve, **runtime}.items():
        setattr(DSHF, name, value)
# -------------------------- Структуры данных --------------------------
@dataclass
class DSHFAction:
    """One per-depth shrink rule used by ``DSHF._block_scale``."""

    enable: bool  # the rule is ignored when False
    timestep: float  # rule matches while the current timestep is >= this value
    depth: int  # U-Net block depth the rule targets
    scale: float  # shrink factor: features are resized by 1/scale going down, by scale coming up
@dataclass
class DSHFExperimentalAction:
    """One experimental parameter set (per-conv scales/dilations, per-block multipliers)."""

    enable: bool  # the set is ignored when False
    timestep: float  # set matches while the current timestep is >= this value
    scales: List[float]  # per-conv shrink factor, indexed by DSHF.currentConv
    in_multipliers: List[float]  # per-block gain after ResBlock in_layers conv (DSHF.currentBlock)
    out_multipliers: List[float]  # per-block gain after ResBlock out_layers conv (DSHF.currentBlock)
    dilations: List[int]  # per-conv dilation, indexed by DSHF.currentConv
    cfg_scale_scale: List[float]  # per-block multiplier applied to the `context` tensor
# -------------------------- Основной скрипт --------------------------
class DSHF(scripts.Script):
    """Deep Shrink Hires.fix: WebUI script plus the patched U-Net it drives.

    All settings live in class attributes. ``process()`` writes them once per
    generation, and the U-Net wrappers defined below read them on every
    forward pass.
    """

    # Per-depth shrink rules; the first enabled rule with timestep <= current wins.
    dshf_actions: List[DSHFAction] = []
    # Experimental master switch and parameter sets (first matching set wins).
    enableExperimental: bool = False
    dshf_experimental_actions: List[DSHFExperimentalAction] = []
    # Counters/state reset at the start of each U-Net forward pass.
    currentBlock: int = 0
    currentConv: int = 0
    currentTimestep: float = 1000.0
    # Global scale curve.
    enable_curve: bool = False
    curve_type: str = "linear"
    curve_t_start: float = 800.0
    curve_t_end: float = 200.0
    curve_scale_start: float = 1.0
    curve_scale_end: float = 1.0
    curve_alpha: float = 0.5  # exponent of the s**alpha gain in DSHF_Unscale
    min_feature_size: int = 8  # shrinks below this spatial size are refused
    auto_end_enable: bool = False
    auto_end_strength: float = 0.35
    # Runtime options.
    timestep_policy: str = "min"
    interp_method: str = "bicubic"
    interp_antialias: bool = True
    channels_last: bool = False
    enable_soft_clamp: bool = False
    soft_clamp_beta: float = 1.5
    min_depth: int = 0
    max_depth: int = 999

    def title(self):
        return "Deep Shrink Hires.fix (RU++ v2.1)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    @staticmethod
    def _active_experimental() -> Optional[DSHFExperimentalAction]:
        """Return the first enabled experimental set whose timestep <= current, else None."""
        if not DSHF.enableExperimental:
            return None
        ts = DSHF.currentTimestep
        for a in DSHF.dshf_experimental_actions:
            if a.enable and a.timestep <= ts:
                return a
        return None

    @staticmethod
    def _curve_weight(ts: float) -> Optional[float]:
        """Evaluate the global scale curve at timestep *ts* (None if the curve is off).

        x = (ts - t_end) / (t_start - t_end), clamped to [0, 1], is shaped by
        the curve type and used to interpolate from scale_start (x=0) to
        scale_end (x=1). The result is clamped to [0.25, 4.0].
        NOTE(review): x=1 at ts >= t_start, so scale_end is what applies at
        the *early* (high) timesteps. Confirm this naming is intended.
        """
        if not DSHF.enable_curve:
            return None
        t0, t1 = DSHF.curve_t_start, DSHF.curve_t_end
        if t0 == t1:
            return DSHF.curve_scale_end
        x = _clamp((ts - t1) / (t0 - t1), 0.0, 1.0)
        if DSHF.curve_type == "linear":
            w = x
        elif DSHF.curve_type == "cosine":
            w = 0.5 - 0.5 * torch.cos(torch.tensor(x) * torch.pi).item()
        else:
            w = 1.0 / (1.0 + torch.exp(torch.tensor(-10.0 * (x - 0.5)))).item()
        s0, s1 = float(DSHF.curve_scale_start), float(DSHF.curve_scale_end)
        return _clamp(s0 + (s1 - s0) * w, 0.25, 4.0)

    @staticmethod
    def _block_scale(depth: int) -> Optional[float]:
        """Combined shrink factor for *depth* at the current timestep, or None.

        Depths outside [min_depth, max_depth] get None. Otherwise the first
        matching rule and the curve value are multiplied (when both exist) and
        clamped to [0.25, 4.0].
        """
        if depth < DSHF.min_depth or depth > DSHF.max_depth:
            return None
        ts = DSHF.currentTimestep
        rule_scale = None
        for a in DSHF.dshf_actions:
            if a.enable and a.depth == depth and a.timestep <= ts:
                rule_scale = a.scale
                break
        curve_scale = DSHF._curve_weight(ts)
        if rule_scale is None and curve_scale is None:
            return None
        if rule_scale is None:
            return curve_scale
        if curve_scale is None:
            return rule_scale
        return _clamp(rule_scale * curve_scale, 0.25, 4.0)

    @staticmethod
    def _auto_scale_end(p) -> Optional[float]:
        """Suggest ``curve_scale_end`` from the hires.fix upscale ratio of *p*.

        r = sqrt(target area / base area). Returns 1 + strength * (r - 1),
        clamped to [1.0, 1.7]. Returns None when the feature/curve is off,
        r <= 1, or the attributes on *p* cannot be read.
        """
        if not DSHF.auto_end_enable or not DSHF.enable_curve:
            return None
        try:
            bw, bh = int(getattr(p, "width", 0)), int(getattr(p, "height", 0))
            if bw <= 0 or bh <= 0:
                return None
            tw, th = bw, bh
            if getattr(p, "enable_hr", False):
                hrx, hry = int(getattr(p, "hr_resize_x", 0)), int(getattr(p, "hr_resize_y", 0))
                hrs = float(getattr(p, "hr_scale", 0.0) or 0.0)
                if hrx > 0 and hry > 0:
                    tw, th = hrx, hry
                elif hrs > 0.0:
                    tw, th = int(round(bw * hrs)), int(round(bh * hrs))
            r = ((max(1, tw * th)) / max(1, bw * bh)) ** 0.5
            if r <= 1.0:
                return None
            return _clamp(1.0 + float(_clamp(DSHF.auto_end_strength, 0.0, 1.0)) * (r - 1.0), 1.0, 1.7)
        except Exception:
            return None

    @staticmethod
    def _pick_timestep_scalar(timesteps: torch.Tensor) -> float:
        """Reduce the batch of timesteps to one value per ``timestep_policy``.

        The policy is "first", "max" or "mean"; anything else means "min".
        """
        pol = DSHF.timestep_policy
        vals = timesteps.detach().float()
        if pol == "first":
            return _to_scalar(vals[0])
        if pol == "max":
            return float(vals.max().item())
        if pol == "mean":
            return float(vals.mean().item())
        return float(vals.min().item())

    @staticmethod
    def _soft_clamp(h: torch.Tensor) -> torch.Tensor:
        """Clamp *h* to mean ± beta·std, computed per sample/channel over dims (2, 3).

        This is a no-op when disabled or beta <= 0.
        """
        if not DSHF.enable_soft_clamp:
            return h
        beta = float(DSHF.soft_clamp_beta)
        if beta <= 0:
            return h
        mean = h.mean(dim=(2, 3), keepdim=True)
        spread = (h.std(dim=(2, 3), keepdim=True) + 1e-6) * beta
        # Bounds are symmetric around the mean. The previous -(mean+spread)
        # lower bound crossed the upper one whenever the mean was negative.
        return torch.minimum(torch.maximum(h, mean - spread), mean + spread)

    @staticmethod
    def _reset_to_passthrough() -> None:
        """Turn off every feature that changes the U-Net output (rules, experiments, curve, clamp)."""
        DSHF.dshf_actions.clear()
        DSHF.enableExperimental = False
        DSHF.dshf_experimental_actions = []
        DSHF.enable_curve = False
        DSHF.auto_end_enable = False
        DSHF.enable_soft_clamp = False

    # --- UI ---
    def ui(self, is_img2img):
        """Build the Gradio UI; returns components in the exact order ``process()`` reads them."""
        presets = _load_all_presets().get("presets", {})
        preset_names = sorted(list(presets.keys()))

        def toggle(v):
            return gr.update(visible=bool(v))

        with gr.Tabs():
            # ===== Settings =====
            with gr.TabItem("Настройки"):
                Enable_Ext = gr.Checkbox(value=True, label="Включить расширение")
                # Main thresholds
                with gr.Accordion(label="Основные пороги (1–2)", open=False):
                    En_Main = gr.Checkbox(value=True, label="Включить секцию")
                    with gr.Group(visible=True) as MainGrp:
                        with gr.Row():
                            Enable_1 = gr.Checkbox(value=True, label="Включить правило 1")
                            Timestep_1 = gr.Number(value=625, label="Timestep 1")
                            Depth_1 = gr.Number(value=3, label="Глубина блока 1", precision=0)
                            Scale_1 = gr.Number(value=2.0, label="Коэффициент масштаба 1")
                        with gr.Row():
                            Enable_2 = gr.Checkbox(value=True, label="Включить правило 2")
                            Timestep_2 = gr.Number(value=0, label="Timestep 2")
                            Depth_2 = gr.Number(value=3, label="Глубина блока 2", precision=0)
                            Scale_2 = gr.Number(value=2.0, label="Коэффициент масштаба 2")
                    En_Main.change(toggle, En_Main, MainGrp)
                # Advanced thresholds
                with gr.Accordion(label="Расширенные пороги (3–8)", open=False):
                    En_Adv = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as AdvGrp:
                        rows = []
                        defaults = [(False, 900, 3, 2.0), (False, 650, 3, 2.0), (False, 900, 3, 2.0),
                                    (False, 650, 3, 2.0), (False, 900, 3, 2.0), (False, 650, 3, 2.0)]
                        for idx, (en, ts, dp, sc) in enumerate(defaults, start=3):
                            with gr.Row():
                                rows.append((
                                    gr.Checkbox(value=en, label=f"Включить правило {idx}"),
                                    gr.Number(value=ts, label=f"Timestep {idx}"),
                                    gr.Number(value=dp, label=f"Глубина блока {idx}", precision=0),
                                    gr.Number(value=sc, label=f"Коэффициент масштаба {idx}"),
                                ))
                    En_Adv.change(toggle, En_Adv, AdvGrp)
                # Experimental
                with gr.Accordion(label="Экспериментальные (масштабы/дилатации/множители)", open=False):
                    Enable_Experimental = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as ExpGrp:
                        def block(prefix, ts_default):
                            with gr.Row():
                                en = gr.Checkbox(value=True, label=f"{prefix}: включить набор")
                                ts = gr.Number(value=ts_default, label=f"{prefix}: timestep")
                            with gr.Row():
                                sc = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: масштабы (по свёрткам)", lines=2)
                            with gr.Row():
                                cfg = gr.Textbox(value="1;1;1; 1;1;1; 1;1;1; 1;1;1; 1; 1;1;1; 1;1;1; 1;1;1; 1;1;1",
                                                 label=f"{prefix}: множители CFG-scale")
                                dil = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: дилатации (по свёрткам)", lines=2)
                            with gr.Row():
                                pre = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: входные умножители (по блокам)")
                                post = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: выходные умножители (по блокам)")
                            return en, ts, sc, pre, post, dil, cfg

                        (Enable_Experimental_1, Timestep_Experimental_1, Scale_Experimental_1,
                         Premultiplier_Experimental_1, Postmultiplier_Experimental_1,
                         Dilation_Experimental_1, CFG_Scale_Scale_Experimental_1) = block("Набор 1", 625)
                        (Enable_Experimental_2, Timestep_Experimental_2, Scale_Experimental_2,
                         Premultiplier_Experimental_2, Postmultiplier_Experimental_2,
                         Dilation_Experimental_2, CFG_Scale_Scale_Experimental_2) = block("Набор 2", 0)
                        (Enable_Experimental_3, Timestep_Experimental_3, Scale_Experimental_3,
                         Premultiplier_Experimental_3, Postmultiplier_Experimental_3,
                         Dilation_Experimental_3, CFG_Scale_Scale_Experimental_3) = block("Набор 3", 750)
                        (Enable_Experimental_4, Timestep_Experimental_4, Scale_Experimental_4,
                         Premultiplier_Experimental_4, Postmultiplier_Experimental_4,
                         Dilation_Experimental_4, CFG_Scale_Scale_Experimental_4) = block("Набор 4", 750)
                    Enable_Experimental.change(toggle, Enable_Experimental, ExpGrp)
                # Curve
                with gr.Accordion(label="Глобальная кривая масштаба", open=False):
                    Enable_Curve = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as CurveGrp:
                        Curve_Type = gr.Dropdown(choices=["linear", "cosine", "sigmoid"], value="linear", label="Тип кривой")
                        with gr.Row():
                            Curve_t_start = gr.Number(value=800, label="t_start")
                            Curve_t_end = gr.Number(value=200, label="t_end")
                        with gr.Row():
                            Curve_scale_start = gr.Number(value=1.0, label="scale_start")
                            Curve_scale_end = gr.Number(value=1.0, label="scale_end")
                        Curve_alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="alpha (компенсация)")
                        Min_feature = gr.Slider(2, 64, value=8, step=1, label="Минимальный размер фичей")
                        with gr.Row():
                            Auto_end_enable = gr.Checkbox(value=False, label="Автоподбор scale_end")
                            Auto_end_strength = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Сила автоподбора")
                    Enable_Curve.change(toggle, Enable_Curve, CurveGrp)
                # Import
                with gr.Accordion(label="Импорт профиля (JSON)", open=False):
                    En_Import = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as ImportGrp:
                        Use_Import = gr.Checkbox(value=False, label="Применить JSON ниже")
                        Json_Profile = gr.Textbox(value="", lines=6, label="JSON: actions/experimental/curve/runtime")
                    En_Import.change(toggle, En_Import, ImportGrp)
                # Presets
                with gr.Accordion(label="Пресеты", open=False):
                    En_Presets = gr.Checkbox(value=False, label="Включить секцию")
                    with gr.Group(visible=False) as PresetGrp:
                        Preset_Apply = gr.Checkbox(value=False, label="Выполнить действие при генерации")
                        Preset_Action = gr.Dropdown(choices=["Сохранить", "Загрузить"], value="Загрузить", label="Действие")
                        with gr.Row():
                            Preset_Name = gr.Textbox(value="", label="Имя пресета")
                            Preset_Existing = gr.Dropdown(choices=preset_names or [""],
                                                          value=(preset_names[0] if preset_names else ""),
                                                          label="Выбрать существующий")
                        gr.Markdown("Подсказка: при «Загрузить» используется поле «Имя пресета», если оно заполнено.")
                    En_Presets.change(toggle, En_Presets, PresetGrp)
            # ===== Runtime =====
            with gr.TabItem("Выполнение"):
                En_Runtime = gr.Checkbox(value=True, label="Включить секцию")
                with gr.Group(visible=True) as RuntimeGrp:
                    with gr.Row():
                        Timestep_Policy = gr.Dropdown(choices=["first", "min", "max", "mean"], value="min", label="Политика timestep")
                        Interp_Method = gr.Dropdown(choices=["nearest", "bilinear", "bicubic", "area"], value="bicubic", label="Интерполяция")
                    with gr.Row():
                        Interp_Antialias = gr.Checkbox(value=True, label="Антиалиасинг (bilinear/bicubic)")
                        Channels_Last = gr.Checkbox(value=False, label="Оптимизация channels_last")
                    with gr.Row():
                        Enable_Soft_Clamp = gr.Checkbox(value=False, label="Мягкий клип амплитуды")
                        Soft_Clamp_Beta = gr.Slider(0.0, 5.0, value=1.5, step=0.1, label="beta (mean±beta·std)")
                    with gr.Row():
                        Min_Depth = gr.Number(value=0, label="Мин. глубина", precision=0)
                        Max_Depth = gr.Number(value=999, label="Макс. глубина", precision=0)
                En_Runtime.change(toggle, En_Runtime, RuntimeGrp)
            # ===== Help =====
            with gr.TabItem("Справка"):
                gr.Markdown("""
**Глобальные тумблеры** у каждой секции управляют и логикой, и показом.
Если секция выключена — её параметры игнорируются в `process()`.
""")
        # Order below is the positional contract with process().
        flat = [Enable_Ext,
                En_Main, Enable_1, Timestep_1, Depth_1, Scale_1, Enable_2, Timestep_2, Depth_2, Scale_2,
                En_Adv]
        for en, ts, dp, sc in rows:
            flat += [en, ts, dp, sc]
        flat += [
            Enable_Experimental,
            Enable_Experimental_1, Timestep_Experimental_1, Scale_Experimental_1,
            Premultiplier_Experimental_1, Postmultiplier_Experimental_1,
            Dilation_Experimental_1, CFG_Scale_Scale_Experimental_1,
            Enable_Experimental_2, Timestep_Experimental_2, Scale_Experimental_2,
            Premultiplier_Experimental_2, Postmultiplier_Experimental_2,
            Dilation_Experimental_2, CFG_Scale_Scale_Experimental_2,
            Enable_Experimental_3, Timestep_Experimental_3, Scale_Experimental_3,
            Premultiplier_Experimental_3, Postmultiplier_Experimental_3,
            Dilation_Experimental_3, CFG_Scale_Scale_Experimental_3,
            Enable_Experimental_4, Timestep_Experimental_4, Scale_Experimental_4,
            Premultiplier_Experimental_4, Postmultiplier_Experimental_4,
            Dilation_Experimental_4, CFG_Scale_Scale_Experimental_4,
            Enable_Curve, Curve_Type, Curve_t_start, Curve_t_end,
            Curve_scale_start, Curve_scale_end, Curve_alpha, Min_feature,
            Auto_end_enable, Auto_end_strength,
            En_Import, Use_Import, Json_Profile,
            En_Presets, Preset_Apply, Preset_Action, Preset_Name, Preset_Existing,
            En_Runtime, Timestep_Policy, Interp_Method, Interp_Antialias, Channels_Last,
            Enable_Soft_Clamp, Soft_Clamp_Beta, Min_Depth, Max_Depth,
        ]
        return flat

    def process(self, p, *args):
        """Copy the UI values (in ``ui()``'s ``flat`` order) into DSHF's class state.

        This only runs when the Deep Shrink U-Net is the active SD U-Net. Values
        are consumed positionally, so every section's values are read even
        when that section is disabled.
        """
        if not isinstance(sd_unet.current_unet, DSHF.DeepShrinkHiresFixUNet):
            return
        it = iter(args)

        def nxt():
            return next(it)

        if not bool(nxt()):
            # Extension unchecked: make the patched U-Net act like the stock one
            # instead of silently reusing the previous generation's settings.
            DSHF._reset_to_passthrough()
            return

        # Main thresholds (rules 1-2)
        en_main = bool(nxt())
        base_rules = [
            (bool(nxt()), _to_scalar(nxt()), int(_to_scalar(nxt())), float(_to_scalar(nxt())))
            for _ in range(2)
        ]
        # Advanced thresholds (rules 3-8)
        en_adv = bool(nxt())
        adv_rules = [
            (bool(nxt()), _to_scalar(nxt()), int(_to_scalar(nxt())), float(_to_scalar(nxt())))
            for _ in range(6)
        ]
        disabled = (False, 0, 0, 1.0)
        rules = (base_rules if en_main else [disabled] * 2) + (adv_rules if en_adv else [disabled] * 6)
        DSHF.dshf_actions.clear()
        for en, ts, dp, sc in rules:
            DSHF.dshf_actions.append(DSHFAction(bool(en), float(ts), int(dp), float(sc)))

        # Experimental sets
        DSHF.enableExperimental = bool(nxt())
        exp_sets = []
        for _ in range(4):
            en, ts, sc, pre, post, dil, cfg = [nxt() for _ in range(7)]
            if not DSHF.enableExperimental:
                # Consume but do not parse: a disabled section must not fail on empty text.
                continue
            exp_sets.append(DSHFExperimentalAction(
                bool(en), _to_scalar(ts),
                _parse_number_list(str(sc), as_int=False),
                _parse_number_list(str(pre), as_int=False),
                _parse_number_list(str(post), as_int=False),
                _parse_number_list(str(dil), as_int=True),
                _parse_number_list(str(cfg), as_int=False),
            ))
        DSHF.dshf_experimental_actions = exp_sets

        # Curve
        DSHF.enable_curve = bool(nxt())
        curve_type = str(nxt())
        t0 = _to_scalar(nxt())
        t1 = _to_scalar(nxt())
        s0 = float(_to_scalar(nxt()))
        s1 = float(_to_scalar(nxt()))
        alpha = float(_clamp(_to_scalar(nxt()), 0.0, 1.0))
        minfeat = int(_clamp(_to_scalar(nxt()), 2, 256))
        auto_en = bool(nxt())
        auto_k = float(_clamp(_to_scalar(nxt()), 0.0, 1.0))
        if DSHF.enable_curve:
            DSHF.curve_type, DSHF.curve_t_start, DSHF.curve_t_end = curve_type, t0, t1
            DSHF.curve_scale_start, DSHF.curve_scale_end = s0, s1
            DSHF.curve_alpha, DSHF.min_feature_size = alpha, minfeat
            DSHF.auto_end_enable, DSHF.auto_end_strength = auto_en, auto_k
        else:
            DSHF.auto_end_enable = False

        # JSON import
        en_import = bool(nxt())
        use_import = bool(nxt())
        json_text = str(nxt() or "").strip()
        if en_import and use_import and json_text:
            try:
                _apply_profile_dict(json.loads(json_text))
            except Exception as e:
                print(f"[DSHF] Ошибка JSON-профиля: {e}")

        # Presets
        en_preset = bool(nxt())
        if en_preset:
            preset_apply = bool(nxt())
            action = str(nxt() or "")
            name = str(nxt() or "").strip()
            existing = str(nxt() or "").strip()
            if preset_apply:
                store = _load_all_presets()
                bag = store.get("presets", {})
                key = name or existing
                if action == "Сохранить":
                    if key:
                        bag[key] = _build_profile_dict()
                        store["presets"] = bag
                        _save_all_presets(store)
                        print(f"[DSHF] Пресет сохранён: '{key}'")
                else:
                    prof = bag.get(key)
                    if prof:
                        _apply_profile_dict(prof)
                        print(f"[DSHF] Пресет загружен: '{key}'")
        else:
            for _ in range(4):
                nxt()

        # Runtime
        en_run = bool(nxt())
        pol = str(nxt())
        im = str(nxt())
        aa = bool(nxt())
        chlast = bool(nxt())
        sclamp = bool(nxt())
        beta = float(_clamp(_to_scalar(nxt()), 0.0, 5.0))
        mind = int(_clamp(_to_scalar(nxt()), 0, 99))
        maxd = int(_clamp(_to_scalar(nxt()), 0, 99))
        if en_run:
            DSHF.timestep_policy, DSHF.interp_method, DSHF.interp_antialias = pol, im, aa
            DSHF.channels_last, DSHF.enable_soft_clamp, DSHF.soft_clamp_beta = chlast, sclamp, beta
            DSHF.min_depth, DSHF.max_depth = mind, maxd
        auto = self._auto_scale_end(p)
        if auto is not None:
            DSHF.curve_scale_end = float(auto)

    # ---------------- Conv2d wrappers ----------------
    class DSHF_Scale(torch.nn.Module):
        """Runs before a wrapped conv: shrinks its input and sets its dilation.

        It acts only while an experimental set is active. Otherwise it restores
        the conv's original dilation/padding, so geometry set at one timestep
        or generation does not leak into later ones.
        """

        def __init__(self, conv2d_ref: List[torch.nn.Conv2d]):
            super().__init__()
            # A plain list keeps the conv from being registered as a second child
            # module (it already lives in the enclosing Sequential).
            self.conv2d_ref = conv2d_ref
            conv = conv2d_ref[0]
            self.orig_dilation = conv.dilation
            self.orig_padding = conv.padding
            # Spatial size the conv would produce on the unshrunk input; the
            # paired DSHF_Unscale resizes back to exactly this size.
            self.target_size: Optional[tuple] = None

        def restore_geometry(self) -> None:
            """Put back the dilation/padding the conv had when it was wrapped."""
            conv = self.conv2d_ref[0]
            conv.dilation = self.orig_dilation
            conv.padding = self.orig_padding

        def _expected_out_size(self, h: torch.Tensor) -> Optional[tuple]:
            """Conv output size for *h* with the original geometry (None for string padding)."""
            conv = self.conv2d_ref[0]
            if isinstance(self.orig_padding, str):
                return None

            def pair(v):
                return tuple(v) if isinstance(v, (tuple, list)) else (v, v)

            k, s = pair(conv.kernel_size), pair(conv.stride)
            pad, dil = pair(self.orig_padding), pair(self.orig_dilation)
            return tuple(
                (n + 2 * pad[i] - dil[i] * (k[i] - 1) - 1) // s[i] + 1
                for i, n in enumerate(h.shape[-2:])
            )

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is None:
                self.restore_geometry()
                self.target_size = None
                return h
            idx = DSHF.currentConv
            self.target_size = self._expected_out_size(h)
            s = float(_get_or_last(exp.scales, idx, 1.0))
            if s > 0.0:  # non-positive scales are ignored instead of dividing by zero
                h = _resize(h, 1.0 / s)
            conv = self.conv2d_ref[0]
            k = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size)
            if max(k) > 1:
                # padding == dilation preserves size for a 3x3 kernel (assumed; SD convs are 3x3 or 1x1).
                dil = max(1, int(_get_or_last(exp.dilations, idx, 1)))
                conv.dilation = (dil, dil)
                conv.padding = (dil, dil)
            else:
                conv.dilation = (1, 1)
                conv.padding = (0, 0)
            return h

    class DSHF_Unscale(torch.nn.Module):
        """Runs after a wrapped conv: restores its size, applies soft clamp, advances currentConv."""

        def __init__(self, scale_ref: Optional[list] = None):
            super().__init__()
            # Optional one-element list with the paired DSHF_Scale (list = not registered).
            self.scale_ref = scale_ref

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                s = float(_get_or_last(exp.scales, DSHF.currentConv, 1.0))
                if s > 0.0 and s != 1.0:
                    target = self.scale_ref[0].target_size if self.scale_ref else None
                    # Resize to the exact expected size so the residual/skip add always matches.
                    h = _resize(h, s) if target is None else _interpolate(h, target)
                    if DSHF.curve_alpha != 0.0:
                        h = h * (s ** DSHF.curve_alpha)
            h = DSHF._soft_clamp(h)
            DSHF.currentConv += 1
            return h

    class DSHF_InMul(torch.nn.Module):
        """Per-block input multiplier (experimental), indexed by currentBlock."""

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                mul = _get_or_last(exp.in_multipliers, DSHF.currentBlock, 1.0)
                if mul != 1.0:
                    return h * mul
            return h

    class DSHF_OutMul(torch.nn.Module):
        """Per-block output multiplier (experimental) followed by the soft clamp."""

        def forward(self, h: torch.Tensor):
            exp = DSHF._active_experimental()
            if exp is not None:
                mul = _get_or_last(exp.out_multipliers, DSHF.currentBlock, 1.0)
                if mul != 1.0:
                    h = h * mul
            return DSHF._soft_clamp(h)

    # --------------- Patched U-Net ---------------
    class DeepShrinkHiresFixUNet(sd_unet.SdUnet):
        """U-Net forward with Deep Shrink; wraps the shared model's convs in place.

        Every replacement is recorded so ``deactivate()`` can undo it when
        another U-Net is selected.
        """

        def __init__(self, _model):
            super().__init__()
            self.model = _model.to(devices.device)
            self._patches: List[tuple] = []  # (container, key, original module)
            self._scalers: List["DSHF.DSHF_Scale"] = []
            for block in [*self.model.input_blocks, self.model.middle_block, *self.model.output_blocks]:
                self._patch_block(block)
            for i, m in list(enumerate(self.model.out)):
                if isinstance(m, torch.nn.Conv2d):
                    self._replace(self.model.out, i, self._wrap(m))

        def _wrap(self, conv, tail=None):
            """Return Sequential(Scale, conv, Unscale[, tail]) with Scale/Unscale paired."""
            scaler = DSHF.DSHF_Scale([conv])
            self._scalers.append(scaler)
            parts = [scaler, conv, DSHF.DSHF_Unscale([scaler])]
            if tail is not None:
                parts.append(tail)
            return torch.nn.Sequential(*parts)

        def _replace(self, container, key, new):
            """Swap ``container[key]`` (int key) or ``container.key`` (str key), remembering the original."""
            if isinstance(key, int):
                self._patches.append((container, key, container[key]))
                container[key] = new
            else:
                self._patches.append((container, key, getattr(container, key)))
                setattr(container, key, new)

        def _patch_block(self, block):
            """Wrap the Conv2d layers of one U-Net block (ResBlock convs, bare convs, resamplers)."""
            for j, layer in list(enumerate(block)):
                if isinstance(layer, ResBlock):
                    for layers, tail_cls in ((layer.in_layers, DSHF.DSHF_InMul), (layer.out_layers, DSHF.DSHF_OutMul)):
                        for k, m in list(enumerate(layers)):
                            if isinstance(m, torch.nn.Conv2d):
                                self._replace(layers, k, self._wrap(m, tail_cls()))
                    continue
                if isinstance(layer, torch.nn.Conv2d):
                    self._replace(block, j, self._wrap(layer))
                # isinstance checks skip pooling ops and already-wrapped Sequentials.
                if isinstance(layer, Downsample) and isinstance(layer.op, torch.nn.Conv2d):
                    self._replace(layer, "op", self._wrap(layer.op))
                if isinstance(layer, Upsample) and isinstance(getattr(layer, "conv", None), torch.nn.Conv2d):
                    self._replace(layer, "conv", self._wrap(layer.conv))

        def deactivate(self):
            """Undo every wrap and restore conv geometry on the shared model."""
            for scaler in self._scalers:
                scaler.restore_geometry()
            for container, key, original in reversed(self._patches):
                if isinstance(key, int):
                    container[key] = original
                else:
                    setattr(container, key, original)
            self._patches.clear()
            self._scalers.clear()

        @staticmethod
        def _block_context(context):
            """Context for the current block, scaled by the active experimental CFG multiplier."""
            exp = DSHF._active_experimental()
            if exp is None:
                return context
            return context * float(_get_or_last(exp.cfg_scale_scale, DSHF.currentBlock, 1.0))

        def forward(self, x, timesteps, context, y=None, **kwargs):
            assert (y is not None) == (self.model.num_classes is not None), "must specify y iff class-conditional"
            if DSHF.channels_last:
                x = x.contiguous(memory_format=torch.channels_last)
            hs = []
            emb = self.model.time_embed(timestep_embedding(timesteps, self.model.model_channels, repeat_only=False))
            if self.model.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.model.label_emb(y)
            h = x.type(self.model.dtype)
            depth = 0
            DSHF.currentBlock = 0
            DSHF.currentConv = 0
            DSHF.currentTimestep = DSHF._pick_timestep_scalar(timesteps)

            # Encoder: optionally shrink before each block; keep outputs as skips.
            for module in self.model.input_blocks:
                scale = DSHF._block_scale(depth)
                if scale is not None:
                    h = _resize(h, 1.0 / float(scale))
                h = module(h, emb, self._block_context(context))
                hs.append(h)
                depth += 1
                DSHF.currentBlock += 1

            # Middle block: shrink before and grow back after, with the same factor.
            scale = DSHF._block_scale(depth)
            if scale is not None:
                h = _resize(h, 1.0 / float(scale))
            h = self.model.middle_block(h, emb, self._block_context(context))
            if scale is not None:
                h = _resize(h, float(scale))
            DSHF.currentBlock += 1

            # Decoder: grow back after each block at the matching depth.
            for module in self.model.output_blocks:
                context_tmp = self._block_context(context)
                depth -= 1
                skip = hs.pop()
                # Rounding or a refused shrink can leave h off by a pixel; match the skip exactly.
                if h.shape[-2:] != skip.shape[-2:]:
                    h = _interpolate(h, tuple(skip.shape[-2:]))
                h = module(torch.cat([h, skip], dim=1), emb, context_tmp)
                scale = DSHF._block_scale(depth)
                if scale is not None:
                    h = _resize(h, float(scale))
                DSHF.currentBlock += 1

            if h.shape[-2:] != x.shape[-2:]:
                h = _interpolate(h, tuple(x.shape[-2:]))
            h = h.type(x.dtype)
            return self.model.id_predictor(h) if self.model.predict_codebook_ids else self.model.out(h)
# Регистрация U-Net
def _create_dshf_unet():
    """Build the patched U-Net around the loaded SD model's diffusion model."""
    return DSHF.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model)


def _register_dshf_unet(unets):
    """list_unets callback: offer the Deep Shrink U-Net in WebUI's U-Net selector."""
    unets.append(DeepShrinkHiresFixUNetOption)


DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption()
DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix"
DeepShrinkHiresFixUNetOption.create_unet = _create_dshf_unet
script_callbacks.on_list_unets(_register_dshf_unet)