|
|
|
|
|
|
|
|
|
|
|
import json
import math
import os
from dataclasses import dataclass
from typing import List, Optional, Dict, Any

import gradio as gr
import torch
import torch.nn.functional as F

import modules.devices as devices
import modules.scripts as scripts
import modules.script_callbacks as script_callbacks
import modules.sd_unet as sd_unet
import modules.shared as shared

from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import Upsample, Downsample, ResBlock
from ldm.modules.diffusionmodules.util import timestep_embedding
|
|
|
|
|
|
|
|
|
|
|
def _to_scalar(x) -> float: |
|
|
if isinstance(x, torch.Tensor): return float(x.item()) |
|
|
return float(x) |
|
|
|
|
|
def _clamp(v: float, lo: float, hi: float) -> float: |
|
|
return max(lo, min(hi, v)) |
|
|
|
|
|
def _safe_size(h: torch.Tensor, scale_factor: float) -> tuple[int, int]:
    """Target (H, W) for scaling h, refusing results below DSHF.min_feature_size.

    Each scaled dimension is floored at 2; if either would still fall under the
    configured minimum feature size, the original size is returned unchanged.
    """
    src_h, src_w = h.shape[-2], h.shape[-1]
    dst_h = max(2, int(round(src_h * scale_factor)))
    dst_w = max(2, int(round(src_w * scale_factor)))
    if min(dst_h, dst_w) < DSHF.min_feature_size:
        return src_h, src_w
    return dst_h, dst_w
|
|
|
|
|
def _interpolate(img: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Resize img to `size` with the configured method, preserving its dtype.

    The work is done in float32 and cast back afterwards. `antialias` only
    applies to bilinear/bicubic; older torch builds do not accept the keyword
    at all, hence the TypeError fallback without it.
    """
    if size == img.shape[-2:]:
        return img
    original_dtype = img.dtype
    mode = DSHF.interp_method
    smooth = mode in ("bilinear", "bicubic")
    align = False if smooth else None
    source = img.float()
    try:
        resized = F.interpolate(
            source, size=size, mode=mode,
            align_corners=align,
            antialias=bool(DSHF.interp_antialias) if smooth else False,
        )
    except TypeError:
        # torch build without the `antialias` keyword.
        resized = F.interpolate(source, size=size, mode=mode, align_corners=align)
    return resized.to(original_dtype)
|
|
|
|
|
def _resize(h: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Spatially scale h by scale_factor; exact no-op for factor 1.0."""
    if scale_factor == 1.0:
        return h
    target = _safe_size(h, scale_factor)
    return _interpolate(h, target)
|
|
|
|
|
def _parse_number_list(text: str, as_int: bool = False) -> List[float]: |
|
|
if text is None: raise ValueError("Пустая строка параметров.") |
|
|
values: List[float] = [] |
|
|
for raw in str(text).replace("\n"," ").split(";"): |
|
|
s = raw.strip() |
|
|
if not s: continue |
|
|
if "/" in s: |
|
|
a,b = s.split("/",1); val = float(a.strip())/float(b.strip()) |
|
|
else: |
|
|
val = float(s) |
|
|
values.append(val) |
|
|
if not values: raise ValueError("Не найдено ни одного валидного значения.") |
|
|
return [int(round(v)) for v in values] if as_int else values |
|
|
|
|
|
def _get_or_last(seq: List[float], index: int, default: float) -> float: |
|
|
if not seq: return default |
|
|
return seq[index] if index < len(seq) else seq[-1] |
|
|
|
|
|
def _preset_path() -> str: |
|
|
return os.path.join(os.path.dirname(__file__), "dshf_presets.json") |
|
|
|
|
|
def _load_all_presets() -> Dict[str, Any]:
    """Read the preset store from disk; any failure yields an empty v2 store."""
    empty_store = {"version": 2, "presets": {}}
    path = _preset_path()
    if not os.path.exists(path):
        return empty_store
    try:
        with open(path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except Exception:
        # Corrupt or unreadable store: behave as if none exists.
        return empty_store
    return data if "presets" in data else empty_store
|
|
|
|
|
def _save_all_presets(data: Dict[str, Any]) -> None:
    """Persist the preset store to disk; failures are logged, never raised."""
    path = _preset_path()
    try:
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"[DSHF] Не удалось сохранить пресеты: {e}")
|
|
|
|
|
def _build_profile_dict() -> Dict[str, Any]:
    """Snapshot the current DSHF configuration as a JSON-serializable dict (schema v2)."""
    actions = [
        {"enable": a.enable, "timestep": a.timestep, "depth": a.depth, "scale": a.scale}
        for a in DSHF.dshf_actions
    ]
    experimental = [
        {
            "enable": e.enable, "timestep": e.timestep, "scales": e.scales,
            "in_multipliers": e.in_multipliers, "out_multipliers": e.out_multipliers,
            "dilations": e.dilations, "cfg_scale_scale": e.cfg_scale_scale,
        }
        for e in DSHF.dshf_experimental_actions
    ]
    curve = {
        "enable": DSHF.enable_curve, "type": DSHF.curve_type,
        "t_start": DSHF.curve_t_start, "t_end": DSHF.curve_t_end,
        "scale_start": DSHF.curve_scale_start, "scale_end": DSHF.curve_scale_end,
        "alpha": DSHF.curve_alpha, "min_feature": DSHF.min_feature_size,
        "auto_end_enable": DSHF.auto_end_enable, "auto_end_strength": DSHF.auto_end_strength,
    }
    runtime = {
        "timestep_policy": DSHF.timestep_policy, "interp_method": DSHF.interp_method,
        "interp_antialias": DSHF.interp_antialias, "channels_last": DSHF.channels_last,
        "enable_soft_clamp": DSHF.enable_soft_clamp, "soft_clamp_beta": DSHF.soft_clamp_beta,
        "min_depth": DSHF.min_depth, "max_depth": DSHF.max_depth,
    }
    return {
        "version": 2,
        "actions": actions,
        "experimental_enable": DSHF.enableExperimental,
        "experimental": experimental,
        "curve": curve,
        "runtime": runtime,
    }
|
|
|
|
|
def _apply_profile_dict(data: Dict[str, Any]) -> None:
    """Apply a profile dict (see _build_profile_dict) onto DSHF's class-level state.

    Missing keys fall back to hard-coded defaults, except for the "runtime"
    section whose defaults are the current DSHF values. Any exception is
    logged and may leave the state only partially applied.
    """
    try:
        # Threshold rules ("actions") — replaces the whole list.
        DSHF.dshf_actions.clear()
        for a in data.get("actions", []):
            DSHF.dshf_actions.append(DSHFAction(bool(a.get("enable",True)),float(a.get("timestep",0)),
                int(a.get("depth",0)),float(a.get("scale",1.0))))
        # Experimental per-conv sets.
        DSHF.enableExperimental = bool(data.get("experimental_enable", False))
        DSHF.dshf_experimental_actions.clear()
        for e in data.get("experimental", []):
            DSHF.dshf_experimental_actions.append(DSHFExperimentalAction(
                bool(e.get("enable",False)), float(e.get("timestep",0)),
                list(map(float, e.get("scales",[]))),
                list(map(float, e.get("in_multipliers",[]))),
                list(map(float, e.get("out_multipliers",[]))),
                list(map(int, e.get("dilations",[]))),
                list(map(float, e.get("cfg_scale_scale",[]))),
            ))
        # Global scale curve; alpha/min_feature/strength are clamped to sane ranges.
        c = data.get("curve", {})
        DSHF.enable_curve = bool(c.get("enable", False))
        DSHF.curve_type = str(c.get("type","linear"))
        DSHF.curve_t_start = float(c.get("t_start",800))
        DSHF.curve_t_end = float(c.get("t_end",200))
        DSHF.curve_scale_start = float(c.get("scale_start",1.0))
        DSHF.curve_scale_end = float(c.get("scale_end",1.0))
        DSHF.curve_alpha = float(_clamp(float(c.get("alpha",0.5)),0.0,1.0))
        DSHF.min_feature_size = int(_clamp(float(c.get("min_feature",8)),2,256))
        DSHF.auto_end_enable = bool(c.get("auto_end_enable", False))
        DSHF.auto_end_strength = float(_clamp(float(c.get("auto_end_strength",0.35)),0.0,1.0))
        # Runtime/execution options (missing keys keep current values).
        r = data.get("runtime", {})
        DSHF.timestep_policy = str(r.get("timestep_policy", DSHF.timestep_policy))
        DSHF.interp_method = str(r.get("interp_method", DSHF.interp_method))
        DSHF.interp_antialias = bool(r.get("interp_antialias", DSHF.interp_antialias))
        DSHF.channels_last = bool(r.get("channels_last", DSHF.channels_last))
        DSHF.enable_soft_clamp = bool(r.get("enable_soft_clamp", DSHF.enable_soft_clamp))
        DSHF.soft_clamp_beta = float(_clamp(float(r.get("soft_clamp_beta", DSHF.soft_clamp_beta)),0.0,5.0))
        DSHF.min_depth = int(_clamp(int(r.get("min_depth", DSHF.min_depth)),0,99))
        DSHF.max_depth = int(_clamp(int(r.get("max_depth", DSHF.max_depth)),0,99))
    except Exception as e:
        print(f"[DSHF] Ошибка применения профиля: {e}")
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DSHFAction:
    """One depth-threshold rule, active while the current timestep is >= `timestep`."""
    enable: bool     # rule participates in matching only when True
    timestep: float  # activation threshold (timesteps count down during sampling)
    depth: int       # UNet block depth the rule applies to
    scale: float     # shrink factor (features are resized by 1/scale)
|
|
|
|
|
@dataclass
class DSHFExperimentalAction:
    """An experimental rule set, active while the current timestep is >= `timestep`."""
    enable: bool
    timestep: float
    scales: List[float]           # per-convolution shrink factors
    in_multipliers: List[float]   # per-block multipliers applied after ResBlock in-layers
    out_multipliers: List[float]  # per-block multipliers applied after ResBlock out-layers
    dilations: List[int]          # per-convolution dilation values
    cfg_scale_scale: List[float]  # per-block context (CFG) multipliers
|
|
|
|
|
|
|
|
|
|
|
class DSHF(scripts.Script):
    """Deep Shrink Hires.fix script: shrinks UNet features at chosen depths/timesteps.

    All configuration lives in class-level attributes so the patched UNet
    (DSHF.DeepShrinkHiresFixUNet) and its hook modules can read it without
    holding a script instance.
    """

    # Threshold rules parsed from the UI (see DSHFAction / DSHFExperimentalAction).
    dshf_actions: List[DSHFAction] = []
    enableExperimental: bool = False
    dshf_experimental_actions: List[DSHFExperimentalAction] = []

    # Per-forward-pass counters/state, written by the patched UNet's forward().
    currentBlock: int = 0
    currentConv: int = 0
    currentTimestep: float = 1000.0

    # Global scale curve settings.
    enable_curve: bool = False
    curve_type: str = "linear"
    curve_t_start: float = 800.0
    curve_t_end: float = 200.0
    curve_scale_start: float = 1.0
    curve_scale_end: float = 1.0
    curve_alpha: float = 0.5       # amplitude compensation exponent (see DSHF_Unscale)
    min_feature_size: int = 8      # refuse resizes that would go below this spatial size
    auto_end_enable: bool = False
    auto_end_strength: float = 0.35

    # Runtime/execution options.
    timestep_policy: str = "min"   # how a batch of timesteps is reduced to one scalar
    interp_method: str = "bicubic"
    interp_antialias: bool = True
    channels_last: bool = False
    enable_soft_clamp: bool = False
    soft_clamp_beta: float = 1.5   # clamp activations to mean ± beta·std when enabled
    min_depth: int = 0
    max_depth: int = 999
|
|
|
|
|
def title(self):
    """Script name shown in the WebUI."""
    return "Deep Shrink Hires.fix (RU++ v2.1)"

def show(self, is_img2img):
    """Run as an always-on script in both txt2img and img2img."""
    return scripts.AlwaysVisible
|
|
|
|
|
@staticmethod
def _active_experimental() -> Optional[DSHFExperimentalAction]:
    """First enabled experimental set whose threshold covers the current timestep."""
    if not DSHF.enableExperimental:
        return None
    now = DSHF.currentTimestep
    matching = (a for a in DSHF.dshf_experimental_actions if a.enable and a.timestep <= now)
    return next(matching, None)
|
|
|
|
|
@staticmethod
def _curve_weight(ts: float) -> Optional[float]:
    """Map timestep `ts` to a scale via the configured global curve, or None if disabled.

    Progress x runs from 0.0 at t_end to 1.0 at t_start; the curve type shapes
    x into a weight w, which blends scale_start -> scale_end and is clamped to
    [0.25, 4.0].

    NOTE(review): at ts == t_start (w == 1) this returns scale_end, and at
    ts == t_end it returns scale_start — the start/end naming looks inverted;
    behavior preserved as-is.
    """
    if not DSHF.enable_curve:
        return None
    t0, t1 = DSHF.curve_t_start, DSHF.curve_t_end
    if t0 == t1:
        # Degenerate curve: jump straight to the end scale.
        return DSHF.curve_scale_end
    x = _clamp((ts - t1) / (t0 - t1), 0.0, 1.0)
    if DSHF.curve_type == "linear":
        w = x
    elif DSHF.curve_type == "cosine":
        # math.cos works in double precision and avoids allocating a tensor
        # per UNet call (the original built float32 torch tensors for scalars).
        w = 0.5 - 0.5 * math.cos(x * math.pi)
    else:  # "sigmoid"
        w = 1.0 / (1.0 + math.exp(-10.0 * (x - 0.5)))
    s0, s1 = float(DSHF.curve_scale_start), float(DSHF.curve_scale_end)
    return _clamp(s0 + (s1 - s0) * w, 0.25, 4.0)
|
|
|
|
|
@staticmethod
def _block_scale(depth: int) -> Optional[float]:
    """Effective shrink scale for a UNet block at `depth`, or None for no-op.

    Combines the first matching threshold rule with the global curve scale;
    when both apply they are multiplied and clamped to [0.25, 4.0].
    """
    if not (DSHF.min_depth <= depth <= DSHF.max_depth):
        return None
    ts = DSHF.currentTimestep
    rule_scale = next(
        (a.scale for a in DSHF.dshf_actions
         if a.enable and a.depth == depth and a.timestep <= ts),
        None,
    )
    curve_scale = DSHF._curve_weight(ts)
    if rule_scale is None:
        return curve_scale  # may itself be None -> no-op
    if curve_scale is None:
        return rule_scale
    return _clamp(rule_scale * curve_scale, 0.25, 4.0)
|
|
|
|
|
@staticmethod
def _auto_scale_end(p) -> Optional[float]:
    """Derive a curve end-scale from the hires-fix upscale ratio of `p`.

    Returns None when the feature or the curve is disabled, when no upscale
    is configured, or on any error while probing `p` (best effort by design).
    The result is clamped to [1.0, 1.7].
    """
    if not DSHF.auto_end_enable or not DSHF.enable_curve:
        return None
    try:
        base_w, base_h = int(getattr(p, "width", 0)), int(getattr(p, "height", 0))
        if base_w <= 0 or base_h <= 0:
            return None
        target_w, target_h = base_w, base_h
        if getattr(p, "enable_hr", False):
            hr_w = int(getattr(p, "hr_resize_x", 0))
            hr_h = int(getattr(p, "hr_resize_y", 0))
            hr_scale = float(getattr(p, "hr_scale", 0.0) or 0.0)
            # Explicit target resolution wins over a relative hr scale.
            if hr_w > 0 and hr_h > 0:
                target_w, target_h = hr_w, hr_h
            elif hr_scale > 0.0:
                target_w = int(round(base_w * hr_scale))
                target_h = int(round(base_h * hr_scale))
        # Linear upscale ratio = sqrt of the pixel-count ratio.
        ratio = ((max(1, target_w * target_h)) / max(1, base_w * base_h)) ** 0.5
        if ratio <= 1.0:
            return None
        strength = float(_clamp(DSHF.auto_end_strength, 0.0, 1.0))
        return _clamp(1.0 + strength * (ratio - 1.0), 1.0, 1.7)
    except Exception:
        return None
|
|
|
|
|
@staticmethod
def _pick_timestep_scalar(timesteps: torch.Tensor) -> float:
    """Reduce a batch of timesteps to a single scalar per DSHF.timestep_policy.

    Recognized policies: "first", "max", "mean"; anything else falls back to min.
    """
    vals = timesteps.detach().float()
    policy = DSHF.timestep_policy
    if policy == "first":
        return _to_scalar(vals[0])
    reducers = {"max": vals.max, "mean": vals.mean}
    reducer = reducers.get(policy, vals.min)
    return float(reducer().item())
|
|
|
|
|
@staticmethod
def _soft_clamp(h: torch.Tensor) -> torch.Tensor:
    """Clamp activations to ±(mean + beta·std) per sample/channel; no-op when disabled.

    NOTE(review): the lower bound is -(mean + beta·std), so the band is
    symmetric around zero rather than around the mean — preserved as-is.
    """
    if not DSHF.enable_soft_clamp:
        return h
    beta = float(DSHF.soft_clamp_beta)
    if beta <= 0:
        return h
    mean = h.mean(dim=(2, 3), keepdim=True)
    std = h.std(dim=(2, 3), keepdim=True) + 1e-6  # epsilon avoids a zero-width band
    limit = mean + std * beta
    bounded_below = torch.maximum(h, -limit)
    return torch.minimum(bounded_below, limit)
|
|
|
|
|
|
|
|
def ui(self, is_img2img):
    """Build the Gradio UI and return the flat component list consumed by process().

    IMPORTANT: process() reads its *args positionally, so the order of the
    returned `flat` list must stay in exact lock-step with process().
    Each section has its own enable checkbox that both toggles visibility of
    the group and gates the section's logic in process().
    """
    presets = _load_all_presets().get("presets", {})
    preset_names = sorted(list(presets.keys()))
    # Shared visibility toggler bound to every per-section enable checkbox.
    def toggle(v): return gr.update(visible=bool(v))

    with gr.Tabs():

        with gr.TabItem("Настройки"):
            Enable_Ext = gr.Checkbox(value=True, label="Включить расширение")

            # Rules 1-2: basic depth/timestep thresholds.
            with gr.Accordion(label="Основные пороги (1–2)", open=False):
                En_Main = gr.Checkbox(value=True, label="Включить секцию")
                with gr.Group(visible=True) as MainGrp:
                    with gr.Row():
                        Enable_1 = gr.Checkbox(value=True, label="Включить правило 1")
                        Timestep_1 = gr.Number(value=625, label="Timestep 1")
                        Depth_1 = gr.Number(value=3, label="Глубина блока 1", precision=0)
                        Scale_1 = gr.Number(value=2.0, label="Коэффициент масштаба 1")
                    with gr.Row():
                        Enable_2 = gr.Checkbox(value=True, label="Включить правило 2")
                        Timestep_2 = gr.Number(value=0, label="Timestep 2")
                        Depth_2 = gr.Number(value=3, label="Глубина блока 2", precision=0)
                        Scale_2 = gr.Number(value=2.0, label="Коэффициент масштаба 2")
                En_Main.change(toggle, En_Main, MainGrp)

            # Rules 3-8, generated in a loop from per-rule defaults.
            with gr.Accordion(label="Расширенные пороги (3–8)", open=False):
                En_Adv = gr.Checkbox(value=False, label="Включить секцию")
                with gr.Group(visible=False) as AdvGrp:
                    rows = []
                    defaults = [(False,900,3,2.0),(False,650,3,2.0),(False,900,3,2.0),
                                (False,650,3,2.0),(False,900,3,2.0),(False,650,3,2.0)]
                    for idx,(en,ts,dp,sc) in enumerate(defaults, start=3):
                        with gr.Row():
                            rows.append((
                                gr.Checkbox(value=en, label=f"Включить правило {idx}"),
                                gr.Number(value=ts, label=f"Timestep {idx}"),
                                gr.Number(value=dp, label=f"Глубина блока {idx}", precision=0),
                                gr.Number(value=sc, label=f"Коэффициент масштаба {idx}")
                            ))
                En_Adv.change(toggle, En_Adv, AdvGrp)

            # Experimental per-conv/per-block controls: four identical sets.
            with gr.Accordion(label="Экспериментальные (масштабы/дилатации/множители)", open=False):
                Enable_Experimental = gr.Checkbox(value=False, label="Включить секцию")
                with gr.Group(visible=False) as ExpGrp:
                    # One set = enable + timestep + five ';'-separated number lists.
                    def block(prefix, ts_default):
                        with gr.Row():
                            en = gr.Checkbox(value=True, label=f"{prefix}: включить набор")
                            ts = gr.Number(value=ts_default, label=f"{prefix}: timestep")
                        with gr.Row():
                            sc = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: масштабы (по свёрткам)", lines=2)
                        with gr.Row():
                            cfg = gr.Textbox(value="1;1;1; 1;1;1; 1;1;1; 1;1;1; 1; 1;1;1; 1;1;1; 1;1;1; 1;1;1",
                                             label=f"{prefix}: множители CFG-scale")
                            dil = gr.Textbox(value="1; " * 52 + "1", label=f"{prefix}: дилатации (по свёрткам)", lines=2)
                        with gr.Row():
                            pre = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: входные умножители (по блокам)")
                            post = gr.Textbox(value="1; " * 24 + "1", label=f"{prefix}: выходные умножители (по блокам)")
                        return en, ts, sc, pre, post, dil, cfg
                    (Enable_Experimental_1, Timestep_Experimental_1, Scale_Experimental_1,
                     Premultiplier_Experimental_1, Postmultiplier_Experimental_1,
                     Dilation_Experimental_1, CFG_Scale_Scale_Experimental_1) = block("Набор 1", 625)
                    (Enable_Experimental_2, Timestep_Experimental_2, Scale_Experimental_2,
                     Premultiplier_Experimental_2, Postmultiplier_Experimental_2,
                     Dilation_Experimental_2, CFG_Scale_Scale_Experimental_2) = block("Набор 2", 0)
                    (Enable_Experimental_3, Timestep_Experimental_3, Scale_Experimental_3,
                     Premultiplier_Experimental_3, Postmultiplier_Experimental_3,
                     Dilation_Experimental_3, CFG_Scale_Scale_Experimental_3) = block("Набор 3", 750)
                    (Enable_Experimental_4, Timestep_Experimental_4, Scale_Experimental_4,
                     Premultiplier_Experimental_4, Postmultiplier_Experimental_4,
                     Dilation_Experimental_4, CFG_Scale_Scale_Experimental_4) = block("Набор 4", 750)
                Enable_Experimental.change(toggle, Enable_Experimental, ExpGrp)

            # Global scale curve (see DSHF._curve_weight).
            with gr.Accordion(label="Глобальная кривая масштаба", open=False):
                Enable_Curve = gr.Checkbox(value=False, label="Включить секцию")
                with gr.Group(visible=False) as CurveGrp:
                    Curve_Type = gr.Dropdown(choices=["linear","cosine","sigmoid"], value="linear", label="Тип кривой")
                    with gr.Row():
                        Curve_t_start = gr.Number(value=800, label="t_start")
                        Curve_t_end = gr.Number(value=200, label="t_end")
                    with gr.Row():
                        Curve_scale_start = gr.Number(value=1.0, label="scale_start")
                        Curve_scale_end = gr.Number(value=1.0, label="scale_end")
                    Curve_alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="alpha (компенсация)")
                    Min_feature = gr.Slider(2, 64, value=8, step=1, label="Минимальный размер фичей")
                    with gr.Row():
                        Auto_end_enable = gr.Checkbox(value=False, label="Автоподбор scale_end")
                        Auto_end_strength = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Сила автоподбора")
                Enable_Curve.change(toggle, Enable_Curve, CurveGrp)

            # Raw JSON profile import (applied in process()).
            with gr.Accordion(label="Импорт профиля (JSON)", open=False):
                En_Import = gr.Checkbox(value=False, label="Включить секцию")
                with gr.Group(visible=False) as ImportGrp:
                    Use_Import = gr.Checkbox(value=False, label="Применить JSON ниже")
                    Json_Profile = gr.Textbox(value="", lines=6, label="JSON: actions/experimental/curve/runtime")
                En_Import.change(toggle, En_Import, ImportGrp)

            # Named presets persisted in dshf_presets.json.
            with gr.Accordion(label="Пресеты", open=False):
                En_Presets = gr.Checkbox(value=False, label="Включить секцию")
                with gr.Group(visible=False) as PresetGrp:
                    Preset_Apply = gr.Checkbox(value=False, label="Выполнить действие при генерации")
                    Preset_Action = gr.Dropdown(choices=["Сохранить","Загрузить"], value="Загрузить", label="Действие")
                    with gr.Row():
                        Preset_Name = gr.Textbox(value="", label="Имя пресета")
                        Preset_Existing = gr.Dropdown(choices=preset_names or [""],
                                                      value=(preset_names[0] if preset_names else ""),
                                                      label="Выбрать существующий")
                    gr.Markdown("Подсказка: при «Загрузить» используется поле «Имя пресета», если оно заполнено.")
                En_Presets.change(toggle, En_Presets, PresetGrp)

        with gr.TabItem("Выполнение"):
            En_Runtime = gr.Checkbox(value=True, label="Включить секцию")
            with gr.Group(visible=True) as RuntimeGrp:
                with gr.Row():
                    Timestep_Policy = gr.Dropdown(choices=["first","min","max","mean"], value="min", label="Политика timestep")
                    Interp_Method = gr.Dropdown(choices=["nearest","bilinear","bicubic","area"], value="bicubic", label="Интерполяция")
                with gr.Row():
                    Interp_Antialias = gr.Checkbox(value=True, label="Антиалиасинг (bilinear/bicubic)")
                    Channels_Last = gr.Checkbox(value=False, label="Оптимизация channels_last")
                with gr.Row():
                    Enable_Soft_Clamp = gr.Checkbox(value=False, label="Мягкий клип амплитуды")
                    Soft_Clamp_Beta = gr.Slider(0.0, 5.0, value=1.5, step=0.1, label="beta (mean±beta·std)")
                with gr.Row():
                    Min_Depth = gr.Number(value=0, label="Мин. глубина", precision=0)
                    Max_Depth = gr.Number(value=999, label="Макс. глубина", precision=0)
            En_Runtime.change(toggle, En_Runtime, RuntimeGrp)

        with gr.TabItem("Справка"):
            gr.Markdown("""
**Глобальные тумблеры** у каждой секции управляют и логикой, и показом.
Если секция выключена — её параметры игнорируются в `process()`.
""")

    # Flat, order-sensitive value list — must stay in sync with process().
    flat = [Enable_Ext,
            En_Main, Enable_1, Timestep_1, Depth_1, Scale_1, Enable_2, Timestep_2, Depth_2, Scale_2,
            En_Adv]
    for en, ts, dp, sc in rows: flat += [en, ts, dp, sc]
    flat += [
        Enable_Experimental,
        Enable_Experimental_1, Timestep_Experimental_1, Scale_Experimental_1,
        Premultiplier_Experimental_1, Postmultiplier_Experimental_1,
        Dilation_Experimental_1, CFG_Scale_Scale_Experimental_1,
        Enable_Experimental_2, Timestep_Experimental_2, Scale_Experimental_2,
        Premultiplier_Experimental_2, Postmultiplier_Experimental_2,
        Dilation_Experimental_2, CFG_Scale_Scale_Experimental_2,
        Enable_Experimental_3, Timestep_Experimental_3, Scale_Experimental_3,
        Premultiplier_Experimental_3, Postmultiplier_Experimental_3,
        Dilation_Experimental_3, CFG_Scale_Scale_Experimental_3,
        Enable_Experimental_4, Timestep_Experimental_4, Scale_Experimental_4,
        Premultiplier_Experimental_4, Postmultiplier_Experimental_4,
        Dilation_Experimental_4, CFG_Scale_Scale_Experimental_4,
        Enable_Curve, Curve_Type, Curve_t_start, Curve_t_end,
        Curve_scale_start, Curve_scale_end, Curve_alpha, Min_feature,
        Auto_end_enable, Auto_end_strength,
        En_Import, Use_Import, Json_Profile,
        En_Presets, Preset_Apply, Preset_Action, Preset_Name, Preset_Existing,
        En_Runtime, Timestep_Policy, Interp_Method, Interp_Antialias, Channels_Last,
        Enable_Soft_Clamp, Soft_Clamp_Beta, Min_Depth, Max_Depth
    ]
    return flat
|
|
|
|
|
def process(self, p, *args):
    """Consume the flat UI values (in ui() order) into DSHF class-level state.

    `args` arrives positionally, so the reads below MUST mirror the order of
    the `flat` list returned by ui(). Bails out early when the Deep Shrink
    UNet is not the active UNet or when the extension toggle is off (in that
    case previously-set class state is left untouched).
    """
    if not isinstance(sd_unet.current_unet, DSHF.DeepShrinkHiresFixUNet): return
    it = iter(args)
    def nxt(): return next(it)

    # Global extension toggle.
    enable_ext = bool(nxt())
    if not enable_ext: return

    # Main threshold rules (1-2): (enable, timestep, depth, scale) each.
    en_main = bool(nxt())
    base_rules = []
    for _ in range(2):
        base_rules.append((
            bool(nxt()), _to_scalar(nxt()), int(_to_scalar(nxt())), float(_to_scalar(nxt()))
        ))

    # Advanced threshold rules (3-8).
    en_adv = bool(nxt())
    adv_rules = []
    for _ in range(6):
        adv_rules.append((
            bool(nxt()), _to_scalar(nxt()), int(_to_scalar(nxt())), float(_to_scalar(nxt()))
        ))

    # Disabled sections are replaced by inert placeholder rules so the list
    # length stays fixed.
    DSHF.dshf_actions.clear()
    rules = (base_rules if en_main else [(False,0,0,1.0)]*2) + (adv_rules if en_adv else [(False,0,0,1.0)]*6)
    for (en,ts,dp,sc) in rules:
        DSHF.dshf_actions.append(DSHFAction(bool(en), float(ts), int(dp), float(sc)))

    # Experimental sets: 4 blocks of (enable, timestep, five number-lists).
    DSHF.enableExperimental = bool(nxt())
    exp_sets = []
    for _ in range(4):
        en = bool(nxt()); ts = _to_scalar(nxt())
        sc = _parse_number_list(str(nxt()), as_int=False)
        pre = _parse_number_list(str(nxt()), as_int=False)
        post = _parse_number_list(str(nxt()), as_int=False)
        dil = _parse_number_list(str(nxt()), as_int=True)
        cfg = _parse_number_list(str(nxt()), as_int=False)
        exp_sets.append(DSHFExperimentalAction(en, ts, sc, pre, post, dil, cfg))
    DSHF.dshf_experimental_actions = exp_sets if DSHF.enableExperimental else []

    # Global curve: values are always read (to keep arg positions) but only
    # applied when the section is enabled.
    DSHF.enable_curve = bool(nxt())
    curve_type = str(nxt()); t0 = _to_scalar(nxt()); t1 = _to_scalar(nxt())
    s0 = float(_to_scalar(nxt())); s1 = float(_to_scalar(nxt()))
    alpha = float(_clamp(_to_scalar(nxt()),0.0,1.0))
    minfeat = int(_clamp(_to_scalar(nxt()),2,256))
    auto_en = bool(nxt()); auto_k = float(_clamp(_to_scalar(nxt()),0.0,1.0))
    if DSHF.enable_curve:
        DSHF.curve_type, DSHF.curve_t_start, DSHF.curve_t_end = curve_type, t0, t1
        DSHF.curve_scale_start, DSHF.curve_scale_end = s0, s1
        DSHF.curve_alpha, DSHF.min_feature_size = alpha, minfeat
        DSHF.auto_end_enable, DSHF.auto_end_strength = auto_en, auto_k
    else:
        DSHF.auto_end_enable = False

    # Optional JSON profile import — overrides everything parsed above.
    en_import = bool(nxt())
    use_import = bool(nxt()); json_text = str(nxt() or "").strip()
    if en_import and use_import and json_text:
        try: _apply_profile_dict(json.loads(json_text))
        except Exception as e: print(f"[DSHF] Ошибка JSON-профиля: {e}")

    # Preset save/load; the disabled branch still consumes its 4 args so the
    # positional stream stays aligned.
    en_preset = bool(nxt())
    if en_preset:
        preset_apply = bool(nxt()); action = str(nxt() or ""); name = str(nxt() or "").strip(); existing = str(nxt() or "").strip()
        if preset_apply:
            store = _load_all_presets(); bag = store.get("presets", {})
            if action == "Сохранить":
                key = name or existing
                if key:
                    bag[key] = _build_profile_dict(); store["presets"] = bag; _save_all_presets(store)
                    print(f"[DSHF] Пресет сохранён: '{key}'")
            else:
                key = name or existing
                prof = bag.get(key)
                if prof: _apply_profile_dict(prof); print(f"[DSHF] Пресет загружен: '{key}'")
    else:
        _ = nxt(); _ = nxt(); _ = nxt(); _ = nxt()

    # Runtime/execution options.
    en_run = bool(nxt())
    pol = str(nxt()); im = str(nxt()); aa = bool(nxt()); chlast = bool(nxt())
    sclamp = bool(nxt()); beta = float(_clamp(_to_scalar(nxt()),0.0,5.0))
    mind = int(_clamp(_to_scalar(nxt()),0,99)); maxd = int(_clamp(_to_scalar(nxt()),0,99))
    if en_run:
        DSHF.timestep_policy, DSHF.interp_method, DSHF.interp_antialias = pol, im, aa
        DSHF.channels_last, DSHF.enable_soft_clamp, DSHF.soft_clamp_beta = chlast, sclamp, beta
        DSHF.min_depth, DSHF.max_depth = mind, maxd

    # Hires-fix-driven override of the curve end scale (see _auto_scale_end).
    auto = self._auto_scale_end(p)
    if auto is not None: DSHF.curve_scale_end = float(auto)
|
|
|
|
|
|
|
|
class DSHF_Scale(torch.nn.Module):
    """Pre-conv hook: shrinks the feature map and adjusts the wrapped conv's dilation.

    The target Conv2d is held via a one-element list (plain list attributes
    are not registered as submodules — presumably to avoid re-registering the
    conv's parameters; verify).
    """
    def __init__(self, conv2d_ref: List[torch.nn.Conv2d]): super().__init__(); self.conv2d_ref = conv2d_ref
    def forward(self, h: torch.Tensor):
        exp = DSHF._active_experimental()
        if exp is not None:
            idx = DSHF.currentConv
            # Shrink by the per-conv factor (scale 2.0 -> resize by 1/2).
            h = _resize(h, 1.0/_get_or_last(exp.scales, idx, 1.0))
            conv = self.conv2d_ref[0]
            k = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size)
            if max(k)>1:
                # NOTE(review): padding = dilation keeps "same" output only for
                # 3x3 kernels (pad = dil*(k-1)/2); assumes k == 3 — confirm.
                dil = int(_get_or_last(exp.dilations, idx, 1)); conv.dilation = (dil,dil); conv.padding = (dil,dil)
            else:
                # 1x1 (and other single-pixel) kernels never dilate/pad.
                conv.dilation = (1,1); conv.padding = (0,0)
        return h
|
|
|
|
|
class DSHF_Unscale(torch.nn.Module):
    """Post-conv hook: restores the feature size shrunk by DSHF_Scale.

    Also applies the alpha amplitude compensation and the optional soft
    clamp, then advances the per-forward conv counter.
    """
    def forward(self, h: torch.Tensor):
        exp = DSHF._active_experimental()
        if exp is not None:
            factor = _get_or_last(exp.scales, DSHF.currentConv, 1.0)
            if factor != 1.0:
                h = _resize(h, factor)
                alpha = DSHF.curve_alpha
                if alpha != 0.0:
                    # Amplitude compensation controlled by the alpha setting.
                    h = h * (factor ** alpha)
        h = DSHF._soft_clamp(h)
        DSHF.currentConv += 1
        return h
|
|
|
|
|
class DSHF_InMul(torch.nn.Module):
    """Multiplies ResBlock in-layer output by the per-block input multiplier."""
    def forward(self, h: torch.Tensor):
        exp = DSHF._active_experimental()
        if exp is None:
            return h
        factor = _get_or_last(exp.in_multipliers, DSHF.currentBlock, 1.0)
        return h if factor == 1.0 else h * factor
|
|
|
|
|
class DSHF_OutMul(torch.nn.Module):
    """Multiplies ResBlock out-layer output by the per-block multiplier, then soft-clamps."""
    def forward(self, h: torch.Tensor):
        exp = DSHF._active_experimental()
        if exp is not None:
            factor = _get_or_last(exp.out_multipliers, DSHF.currentBlock, 1.0)
            if factor != 1.0:
                h = h * factor
        return DSHF._soft_clamp(h)
|
|
|
|
|
|
|
|
class DeepShrinkHiresFixUNet(sd_unet.SdUnet):
    """UNet wrapper that instruments every Conv2d with DSHF scale/unscale hooks."""

    def __init__(self, _model):
        """Wrap all convolutions of `_model` (input/middle/output blocks and `out`).

        The original code repeated the identical wrapping loop three times
        (input blocks, middle block, output blocks); it is factored into local
        helpers here — the produced module tree is unchanged.
        """
        super().__init__()
        self.model = _model.to(devices.device)

        def wrap(conv):
            # conv -> (shrink, conv, unshrink)
            return torch.nn.Sequential(DSHF.DSHF_Scale([conv]), conv, DSHF.DSHF_Unscale())

        def wrap_with(conv, tail):
            # ResBlock convs additionally get a per-block multiplier stage.
            return torch.nn.Sequential(DSHF.DSHF_Scale([conv]), conv, DSHF.DSHF_Unscale(), tail)

        def instrument(seq):
            # Replace eligible layers inside one block sequence, in place.
            for j, layer in enumerate(seq):
                if isinstance(layer, ResBlock):
                    for k, il in enumerate(layer.in_layers):
                        if isinstance(il, torch.nn.Conv2d):
                            layer.in_layers[k] = wrap_with(il, DSHF.DSHF_InMul())
                    for k, ol in enumerate(layer.out_layers):
                        if isinstance(ol, torch.nn.Conv2d):
                            layer.out_layers[k] = wrap_with(ol, DSHF.DSHF_OutMul())
                else:
                    if isinstance(layer, torch.nn.Conv2d):
                        seq[j] = wrap(layer)
                    if isinstance(layer, Downsample):
                        layer.op = wrap(layer.op)
                    if isinstance(layer, Upsample):
                        layer.conv = wrap(layer.conv)

        for block in self.model.input_blocks:
            instrument(block)
        instrument(self.model.middle_block)
        for block in self.model.output_blocks:
            instrument(block)
        # Final projection: plain scale/unscale, no per-block multipliers.
        for i, m in enumerate(self.model.out):
            if isinstance(m, torch.nn.Conv2d):
                self.model.out[i] = wrap(m)
|
|
|
|
|
def forward(self, x, timesteps, context, y=None, **kwargs):
    """UNet forward pass with per-depth feature shrinking/restoring.

    Mirrors the stock OpenAI UNet forward, but resizes `h` around blocks per
    DSHF._block_scale and multiplies the cross-attention context by the
    experimental per-block CFG factor. Resets the per-pass DSHF counters at
    entry.

    NOTE(review): skip tensors are stored at the (possibly shrunk) encoder
    size and restoring happens only *after* each output block — rules that
    are not symmetric across matching depths could make torch.cat sizes
    mismatch; confirm against the rule semantics.
    """
    assert (y is not None) == (self.model.num_classes is not None), "must specify y iff class-conditional"
    if DSHF.channels_last: x = x.contiguous(memory_format=torch.channels_last)
    hs = []
    emb = self.model.time_embed(timestep_embedding(timesteps, self.model.model_channels, repeat_only=False))
    if self.model.num_classes is not None:
        assert y.shape[0]==x.shape[0]; emb = emb + self.model.label_emb(y)
    h = x.type(self.model.dtype); depth = 0
    # Per-pass bookkeeping consumed by the DSHF_* hook modules.
    DSHF.currentBlock = 0; DSHF.currentConv = 0; DSHF.currentTimestep = DSHF._pick_timestep_scalar(timesteps)

    # Encoder: shrink before each input block; sizes are restored in the decoder.
    for module in self.model.input_blocks:
        context_tmp = context
        scale = DSHF._block_scale(depth)
        if scale is not None: h = _resize(h, 1.0/float(scale))
        exp = DSHF._active_experimental()
        if exp is not None:
            cfg_mul = _get_or_last(exp.cfg_scale_scale, DSHF.currentBlock, 1.0)
            context_tmp = context * float(cfg_mul)
        h = module(h, emb, context_tmp); hs.append(h); depth += 1; DSHF.currentBlock += 1

    # Middle block: shrink on entry, restore immediately after (same depth).
    context_tmp = context; scale = DSHF._block_scale(depth)
    if scale is not None: h = _resize(h, 1.0/float(scale))
    exp = DSHF._active_experimental()
    if exp is not None:
        cfg_mul = _get_or_last(exp.cfg_scale_scale, DSHF.currentBlock, 1.0)
        context_tmp = context * float(cfg_mul)
    h = self.model.middle_block(h, emb, context_tmp)
    scale = DSHF._block_scale(depth)
    if scale is not None: h = _resize(h, float(scale))
    DSHF.currentBlock += 1

    # Decoder: concat the skip, run the block, then restore at the matching depth.
    for module in self.model.output_blocks:
        context_tmp = context
        exp = DSHF._active_experimental()
        if exp is not None:
            cfg_mul = _get_or_last(exp.cfg_scale_scale, DSHF.currentBlock, 1.0)
            context_tmp = context * float(cfg_mul)
        depth -= 1; h = torch.cat([h, hs.pop()], dim=1); h = module(h, emb, context_tmp)
        scale = DSHF._block_scale(depth)
        if scale is not None: h = _resize(h, float(scale))
        DSHF.currentBlock += 1

    h = h.type(x.dtype)
    return self.model.id_predictor(h) if self.model.predict_codebook_ids else self.model.out(h)
|
|
|
|
|
|
|
|
# Register the Deep Shrink UNet as a selectable SD-UNet option in the WebUI.
DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption()
DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix"
# Lazily wraps the currently loaded model's diffusion UNet when selected.
DeepShrinkHiresFixUNetOption.create_unet = lambda: DSHF.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model)
script_callbacks.on_list_unets(lambda unets: unets.append(DeepShrinkHiresFixUNetOption))
|
|
|