# Deep Shrink Hires.fix (RU++ v2.3.1 LTS + Experimental, FIXED) # Исправления: статические _block_scale/_curve_weight; синхронизация instance→class после process(). from dataclasses import dataclass from typing import List, Optional, Dict, Any, Tuple import json, os import gradio as gr import torch import torch.nn.functional as F import modules.devices as devices import modules.scripts as scripts import modules.script_callbacks as script_callbacks import modules.sd_unet as sd_unet import modules.shared as shared from ldm.modules.diffusionmodules.openaimodel import Upsample, Downsample, ResBlock from ldm.modules.diffusionmodules.util import timestep_embedding # -------------------------- Утилиты -------------------------- def _to_scalar(x) -> float: if isinstance(x, torch.Tensor): return float(x.item()) try: return float(x) except Exception: return 0.0 def _clamp(v: float, lo: float, hi: float) -> float: return max(lo, min(hi, v)) def _get_or_last(seq: List[float], index: int, default: float) -> float: if not seq: return default return seq[index] if index < len(seq) else seq[-1] def _safe_size(h: torch.Tensor, scale_factor: float, min_feat: int) -> Tuple[int, int]: hi, wi = h.shape[-2], h.shape[-1] ho = max(2, int(round(hi * scale_factor))) wo = max(2, int(round(wi * scale_factor))) if ho < min_feat or wo < min_feat: return hi, wi return ho, wo def _interpolate(img: torch.Tensor, size: Tuple[int, int], mode: str, antialias: bool) -> torch.Tensor: if size == img.shape[-2:]: return img dtype = img.dtype try: out = F.interpolate( img.float(), size=size, mode=mode, align_corners=False if mode in ("bilinear", "bicubic") else None, antialias=(antialias if mode in ("bilinear", "bicubic") else False), ) except TypeError: out = F.interpolate( img.float(), size=size, mode=mode, align_corners=False if mode in ("bilinear", "bicubic") else None, ) return out.to(dtype) def _resize(h: torch.Tensor, scale: float, mode: str, antialias: bool, min_feat: int) -> torch.Tensor: if scale == 1.0: return h size = _safe_size(h, scale, min_feat) return _interpolate(h, size, mode, antialias) def _parse_number_list(text: str, as_int: bool = False) -> List[float]: text = (text or "").replace("\n", " ") vals: List[float] = [] for chunk in text.split(";"): s = chunk.strip() if not s: continue if "/" in s: a, b = s.split("/", 1) v = float(a.strip()) / float(b.strip()) else: v = float(s) vals.append(v) if not vals: raise ValueError("Список пуст.") return [int(round(v)) for v in vals] if as_int else vals # -------------------------- Данные -------------------------- @dataclass class DSHFAction: enable: bool timestep: float depth: int scale: float # -------------------------- Скрипт -------------------------- class DSHF(scripts.Script): # Счётчики текущего прохода (класс-поля нужны U-Net'у) currentBlock: int = 0 currentConv: int = 0 currentTimestep: float = 1000.0 # Глобальные параметры (класс-поля: U-Net читает их как DSHF.*) enabled: bool = True interp_method: str = "bicubic" interp_antialias: bool = True channels_last: bool = False min_feature_size: int = 8 # Кривая curve_enable: bool = False curve_type: str = "linear" curve_t_start: float = 800.0 curve_t_end: float = 200.0 curve_scale_start: float = 1.0 curve_scale_end: float = 1.0 curve_alpha: float = 0.5 auto_end_enable: bool = False auto_end_strength: float = 0.35 # Пороговые правила actions: List[DSHFAction] = [] # Experimental (класс-поля для доступа из U-Net) exp_section_enable: bool = False exp_enable: bool = False exp_timestep: float = 625.0 exp_scales: List[float] = [1.0] exp_dilations: List[int] = [1] exp_in_muls: List[float] = [1.0] exp_out_muls: List[float] = [1.0] exp_cfg_muls: List[float] = [1.0] # ---------------- UI ---------------- def title(self): return "Deep Shrink Hires.fix (RU++ v2.3.1)" def show(self, is_img2img): return scripts.AlwaysVisible def ui(self, is_img2img): with gr.Tabs(): with gr.TabItem("Настройки"): Enable_Ext = gr.Checkbox(value=True, label="Включить расширение") with gr.Accordion("Пороги (до 8 правил)", open=True): Rule_Count = gr.Slider(1, 8, value=2, step=1, label="Сколько правил использовать") rules = [] defaults = [ (True, 625, 3, 2.0), (True, 0, 3, 2.0), (False, 900, 3, 2.0), (False, 650, 3, 2.0), (False, 900, 3, 2.0), (False, 650, 3, 2.0), (False, 900, 3, 2.0), (False, 650, 3, 2.0), ] for i, (en, ts, dp, sc) in enumerate(defaults, start=1): with gr.Row(): rules.append(( gr.Checkbox(value=en, label=f"Правило {i}"), gr.Number(value=ts, label=f"Timestep {i}"), gr.Number(value=dp, label=f"Глубина блока {i}", precision=0), gr.Number(value=sc, label=f"Масштаб {i}"), )) with gr.Accordion("Глобальная кривая масштаба", open=False): Curve_Enable = gr.Checkbox(value=False, label="Включить кривую") Curve_Type = gr.Dropdown(choices=["linear", "cosine", "sigmoid"], value="linear", label="Тип") with gr.Row(): Curve_t_start = gr.Number(value=800, label="t_start") Curve_t_end = gr.Number(value=200, label="t_end") with gr.Row(): Curve_scale_start = gr.Number(value=1.0, label="scale_start") Curve_scale_end = gr.Number(value=1.0, label="scale_end") Curve_alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="alpha компенсации") Min_feature = gr.Slider(2, 64, value=8, step=1, label="Мин. размер фичей") with gr.Row(): Auto_end_enable = gr.Checkbox(value=False, label="Автоподбор scale_end по целевому разрешению") Auto_end_strength = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Сила автоподбора") with gr.Accordion("Выполнение", open=False): Interp_Method = gr.Dropdown(choices=["nearest", "bilinear", "bicubic", "area"], value="bicubic", label="Интерполяция") Interp_AA = gr.Checkbox(value=True, label="Антиалиасинг для bilinear/bicubic") Channels_Last = gr.Checkbox(value=False, label="Оптимизация channels_last") with gr.Accordion("Экспериментальные (пер-свёрточные)", open=False): Exp_Section_Enable = gr.Checkbox(value=False, label="Включить секцию") with gr.Group(visible=False) as ExpGrp: Exp_Enable = gr.Checkbox(value=False, label="Активировать экспериментальное ядро") Exp_Timestep = gr.Number(value=625, label="Пороговой timestep для эксперимента") Exp_Scales = gr.Textbox(value="1", lines=2, label="Масштабы по свёрткам") Exp_Dilations = gr.Textbox(value="1", lines=1, label="Дилатации по свёрткам (целые)") Exp_InMuls = gr.Textbox(value="1", lines=1, label="Входные умножители по блокам") Exp_OutMuls = gr.Textbox(value="1", lines=1, label="Выходные умножители по блокам") Exp_CFGMuls = gr.Textbox(value="1", lines=1, label="CFG-множители по блокам") Exp_Section_Enable.change(lambda v: gr.update(visible=bool(v)), Exp_Section_Enable, ExpGrp) with gr.Accordion("Пресеты", open=False): Preset_Action = gr.Dropdown(choices=["Сохранить", "Загрузить"], value="Загрузить", label="Действие") Preset_Name = gr.Textbox(value="", label="Имя пресета") Preset_JSON = gr.Textbox(value="", lines=6, label="Профиль в JSON (для импорта/экспорта)") with gr.TabItem("Справка"): gr.Markdown(""" **LTS** — безопасное масштабирование только на границах блоков. **Experimental** — опционально: пер-свёрточные масштабы, дилатации, In/Out и CFG-мультипликаторы. """) flat = [Enable_Ext, Rule_Count] for row in rules: flat += list(row) flat += [ Curve_Enable, Curve_Type, Curve_t_start, Curve_t_end, Curve_scale_start, Curve_scale_end, Curve_alpha, Min_feature, Auto_end_enable, Auto_end_strength, Interp_Method, Interp_AA, Channels_Last, Exp_Section_Enable, Exp_Enable, Exp_Timestep, Exp_Scales, Exp_Dilations, Exp_InMuls, Exp_OutMuls, Exp_CFGMuls, Preset_Action, Preset_Name, Preset_JSON, ] return flat # ---------------- Исполнение ---------------- def _reset_instance_defaults(self): # только для локальной сборки входных значений; класс-поля перезапишем ниже self._inst_actions: List[DSHFAction] = [] self._inst_curve_enable = False self._inst_curve_type = "linear" self._inst_curve_t_start = 800.0 self._inst_curve_t_end = 200.0 self._inst_curve_scale_start = 1.0 self._inst_curve_scale_end = 1.0 self._inst_curve_alpha = 0.5 self._inst_auto_end_enable = False self._inst_auto_end_strength = 0.35 self._inst_interp_method = "bicubic" self._inst_interp_antialias = True self._inst_channels_last = False self._inst_min_feature_size = 8 self._inst_exp_section_enable = False self._inst_exp_enable = False self._inst_exp_timestep = 625.0 self._inst_exp_scales = [1.0] self._inst_exp_dilations = [1] self._inst_exp_in_muls = [1.0] self._inst_exp_out_muls = [1.0] self._inst_exp_cfg_muls = [1.0] @staticmethod def _curve_weight(ts: float) -> Optional[float]: if not DSHF.curve_enable: return None t0, t1 = DSHF.curve_t_start, DSHF.curve_t_end if t0 == t1: return DSHF.curve_scale_end x = _clamp((ts - t1) / (t0 - t1), 0.0, 1.0) if DSHF.curve_type == "linear": w = x elif DSHF.curve_type == "cosine": w = 0.5 - 0.5 * torch.cos(torch.tensor(x) * torch.pi).item() else: w = 1.0 / (1.0 + torch.exp(torch.tensor(-10.0 * (x - 0.5)))).item() s = DSHF.curve_scale_start + (DSHF.curve_scale_end - DSHF.curve_scale_start) * w return _clamp(float(s), 0.25, 4.0) @staticmethod def _block_scale(depth: int, ts: float) -> Optional[float]: rule_scale: Optional[float] = None for a in DSHF.actions: if a.enable and a.depth == depth and a.timestep <= ts: rule_scale = a.scale break curve_scale = DSHF._curve_weight(ts) if rule_scale is None and curve_scale is None: return None if rule_scale is None: return curve_scale if curve_scale is None: return rule_scale return _clamp(rule_scale * curve_scale, 0.25, 4.0) def _auto_scale_end(self, p) -> Optional[float]: if not self._inst_auto_end_enable or not self._inst_curve_enable: return None try: bw, bh = int(getattr(p, "width", 0)), int(getattr(p, "height", 0)) if bw <= 0 or bh <= 0: return None tw, th = bw, bh if getattr(p, "enable_hr", False): hrx = int(getattr(p, "hr_resize_x", 0)) hry = int(getattr(p, "hr_resize_y", 0)) hrs = float(getattr(p, "hr_scale", 0.0) or 0.0) if hrx > 0 and hry > 0: tw, th = hrx, hry elif hrs > 0.0: tw, th = int(round(bw * hrs)), int(round(bh * hrs)) r = ((max(1, tw * th)) / max(1, bw * bh)) ** 0.5 if r <= 1.0: return None return _clamp(1.0 + _clamp(self._inst_auto_end_strength, 0.0, 1.0) * (r - 1.0), 1.0, 1.7) except Exception: return None def process(self, p, *args): # Используем только с нашим UNet if not isinstance(sd_unet.current_unet, DSHF.DeepShrinkHiresFixUNet): return it = iter(args) def nxt(): return next(it) # Глобальный тумблер enabled = bool(nxt()) if not enabled: DSHF.enabled = False return # Сбор значений в instance-поле, чтобы не держать мусор в классе self._reset_instance_defaults() # ---- Правила ---- rule_count = int(_clamp(_to_scalar(nxt()), 1, 8)) tmp_rules: List[DSHFAction] = [] for _ in range(8): en = bool(nxt()); ts = _to_scalar(nxt()); dp = int(_to_scalar(nxt())); sc = float(_to_scalar(nxt())) tmp_rules.append(DSHFAction(en, ts, dp, sc)) self._inst_actions = tmp_rules[:rule_count] # ---- Кривая ---- self._inst_curve_enable = bool(nxt()) self._inst_curve_type = str(nxt()) self._inst_curve_t_start = _to_scalar(nxt()) self._inst_curve_t_end = _to_scalar(nxt()) self._inst_curve_scale_start = float(_to_scalar(nxt())) self._inst_curve_scale_end = float(_to_scalar(nxt())) self._inst_curve_alpha = float(_clamp(_to_scalar(nxt()), 0.0, 1.0)) self._inst_min_feature_size = int(_clamp(_to_scalar(nxt()), 2, 256)) self._inst_auto_end_enable = bool(nxt()) self._inst_auto_end_strength = float(_clamp(_to_scalar(nxt()), 0.0, 1.0)) # ---- Выполнение ---- self._inst_interp_method = str(nxt()) self._inst_interp_antialias = bool(nxt()) self._inst_channels_last = bool(nxt()) # ---- Experimental ---- self._inst_exp_section_enable = bool(nxt()) self._inst_exp_enable = bool(nxt()) self._inst_exp_timestep = _to_scalar(nxt()) try: self._inst_exp_scales = list(map(float, _parse_number_list(str(nxt()), as_int=False))) except Exception: self._inst_exp_scales = [1.0] try: self._inst_exp_dilations = list(map(int, _parse_number_list(str(nxt()), as_int=True))) except Exception: self._inst_exp_dilations = [1] try: self._inst_exp_in_muls = list(map(float, _parse_number_list(str(nxt()), as_int=False))) except Exception: self._inst_exp_in_muls = [1.0] try: self._inst_exp_out_muls = list(map(float, _parse_number_list(str(nxt()), as_int=False))) except Exception: self._inst_exp_out_muls = [1.0] try: self._inst_exp_cfg_muls = list(map(float, _parse_number_list(str(nxt()), as_int=False))) except Exception: self._inst_exp_cfg_muls = [1.0] # ---- Пресеты ---- preset_action = str(nxt() or "") preset_name = str(nxt() or "").strip() preset_json = str(nxt() or "").strip() preset_path = os.path.join(os.path.dirname(__file__), "dshf_presets.json") if preset_action == "Сохранить" and preset_name: data = self._export_profile_instance() try: cur = {"version": 1, "presets": {}} if os.path.exists(preset_path): with open(preset_path, "r", encoding="utf-8") as f: cur = json.load(f) cur["presets"][preset_name] = data with open(preset_path, "w", encoding="utf-8") as f: json.dump(cur, f, ensure_ascii=False, indent=2) print(f"[DSHF] Пресет сохранён: {preset_name}") except Exception as e: print(f"[DSHF] Не удалось сохранить пресет: {e}") elif preset_action == "Загрузить": if preset_json: try: prof = json.loads(preset_json) self._apply_profile_instance(prof) print("[DSHF] Профиль применён из JSON") except Exception as e: print(f"[DSHF] Ошибка JSON: {e}") elif preset_name: try: with open(preset_path, "r", encoding="utf-8") as f: cur = json.load(f) prof = cur.get("presets", {}).get(preset_name) if prof: self._apply_profile_instance(prof) print(f"[DSHF] Профиль загружен: {preset_name}") except Exception as e: print(f"[DSHF] Не удалось загрузить пресет: {e}") # Автоподбор конца кривой auto = self._auto_scale_end(p) if auto is not None: self._inst_curve_scale_end = float(auto) # -------- СИНХРОНИЗАЦИЯ instance → class (то, что читает U-Net) -------- DSHF.enabled = True DSHF.actions = list(self._inst_actions) DSHF.curve_enable = bool(self._inst_curve_enable) DSHF.curve_type = str(self._inst_curve_type) DSHF.curve_t_start = float(self._inst_curve_t_start) DSHF.curve_t_end = float(self._inst_curve_t_end) DSHF.curve_scale_start = float(self._inst_curve_scale_start) DSHF.curve_scale_end = float(self._inst_curve_scale_end) DSHF.curve_alpha = float(self._inst_curve_alpha) DSHF.min_feature_size = int(self._inst_min_feature_size) DSHF.auto_end_enable = bool(self._inst_auto_end_enable) DSHF.auto_end_strength = float(self._inst_auto_end_strength) DSHF.interp_method = str(self._inst_interp_method) DSHF.interp_antialias = bool(self._inst_interp_antialias) DSHF.channels_last = bool(self._inst_channels_last) # experimental DSHF.exp_section_enable = bool(self._inst_exp_section_enable) DSHF.exp_enable = bool(self._inst_exp_section_enable and self._inst_exp_enable) DSHF.exp_timestep = float(self._inst_exp_timestep) DSHF.exp_scales = list(self._inst_exp_scales) DSHF.exp_dilations = list(self._inst_exp_dilations) DSHF.exp_in_muls = list(self._inst_exp_in_muls) DSHF.exp_out_muls = list(self._inst_exp_out_muls) DSHF.exp_cfg_muls = list(self._inst_exp_cfg_muls) # --------- Профили (instance-вариант, чтобы сохранять то, что в UI) --------- def _export_profile_instance(self) -> Dict[str, Any]: return { "actions": [a.__dict__ for a in self._inst_actions], "curve": dict( enable=self._inst_curve_enable, type=self._inst_curve_type, t_start=self._inst_curve_t_start, t_end=self._inst_curve_t_end, scale_start=self._inst_curve_scale_start, scale_end=self._inst_curve_scale_end, alpha=self._inst_curve_alpha, min_feature=self._inst_min_feature_size, auto_end_enable=self._inst_auto_end_enable, auto_end_strength=self._inst_auto_end_strength, ), "runtime": dict( interp_method=self._inst_interp_method, interp_antialias=self._inst_interp_antialias, channels_last=self._inst_channels_last ), "experimental": dict( section_enable=self._inst_exp_section_enable, enable=self._inst_exp_enable, timestep=self._inst_exp_timestep, scales=self._inst_exp_scales, dilations=self._inst_exp_dilations, in_muls=self._inst_exp_in_muls, out_muls=self._inst_exp_out_muls, cfg_muls=self._inst_exp_cfg_muls ), } def _apply_profile_instance(self, data: Dict[str, Any]) -> None: try: self._inst_actions = [DSHFAction(bool(a.get("enable", True)), float(a.get("timestep", 0)), int(a.get("depth", 0)), float(a.get("scale", 1.0))) for a in data.get("actions", [])] c = data.get("curve", {}) self._inst_curve_enable = bool(c.get("enable", False)) self._inst_curve_type = str(c.get("type", "linear")) self._inst_curve_t_start = float(c.get("t_start", 800)) self._inst_curve_t_end = float(c.get("t_end", 200)) self._inst_curve_scale_start = float(c.get("scale_start", 1.0)) self._inst_curve_scale_end = float(c.get("scale_end", 1.0)) self._inst_curve_alpha = float(_clamp(float(c.get("alpha", 0.5)), 0.0, 1.0)) self._inst_min_feature_size = int(_clamp(float(c.get("min_feature", 8)), 2, 256)) self._inst_auto_end_enable = bool(c.get("auto_end_enable", False)) self._inst_auto_end_strength = float(_clamp(float(c.get("auto_end_strength", 0.35)), 0.0, 1.0)) r = data.get("runtime", {}) self._inst_interp_method = str(r.get("interp_method", self._inst_interp_method)) self._inst_interp_antialias = bool(r.get("interp_antialias", self._inst_interp_antialias)) self._inst_channels_last = bool(r.get("channels_last", self._inst_channels_last)) e = data.get("experimental", {}) self._inst_exp_section_enable = bool(e.get("section_enable", False)) self._inst_exp_enable = bool(e.get("enable", False)) self._inst_exp_timestep = float(e.get("timestep", 625)) self._inst_exp_scales = list(map(float, e.get("scales", [1.0]))) self._inst_exp_dilations = list(map(int, e.get("dilations", [1]))) self._inst_exp_in_muls = list(map(float, e.get("in_muls", [1.0]))) self._inst_exp_out_muls = list(map(float, e.get("out_muls", [1.0]))) self._inst_exp_cfg_muls = list(map(float, e.get("cfg_muls", [1.0]))) except Exception as ex: print(f"[DSHF] Ошибка применения профиля: {ex}") # --------- Обёртки Conv2d (работают только при включённом experimental) --------- class DSHF_Scale(torch.nn.Module): def __init__(self, conv2d_ref: List[torch.nn.Conv2d], get_rt): super().__init__() self.conv2d_ref = conv2d_ref self.get_rt = get_rt # -> (mode, aa, min_feat) def forward(self, h: torch.Tensor): if not DSHF.exp_enable or DSHF.currentTimestep < DSHF.exp_timestep: return h mode, aa, min_feat = self.get_rt() idx = DSHF.currentConv pre_scale = 1.0 / _get_or_last(DSHF.exp_scales, idx, 1.0) if pre_scale != 1.0: h = _resize(h, pre_scale, mode, aa, min_feat) conv = self.conv2d_ref[0] k = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size) if max(k) > 1: dil = int(_get_or_last(DSHF.exp_dilations, idx, 1)) conv.dilation = (dil, dil); conv.padding = (dil, dil) else: conv.dilation = (1, 1); conv.padding = (0, 0) return h class DSHF_Unscale(torch.nn.Module): def __init__(self, get_rt): super().__init__() self.get_rt = get_rt def forward(self, h: torch.Tensor): if not DSHF.exp_enable or DSHF.currentTimestep < DSHF.exp_timestep: DSHF.currentConv += 1; return h mode, aa, min_feat = self.get_rt() idx = DSHF.currentConv post_scale = _get_or_last(DSHF.exp_scales, idx, 1.0) if post_scale != 1.0: h = _resize(h, post_scale, mode, aa, min_feat) alpha = float(DSHF.curve_alpha) if alpha != 0.0: h = h * (post_scale ** alpha) DSHF.currentConv += 1 return h class DSHF_InMul(torch.nn.Module): def forward(self, h: torch.Tensor): if not DSHF.exp_enable or DSHF.currentTimestep < DSHF.exp_timestep: return h mul = _get_or_last(DSHF.exp_in_muls, DSHF.currentBlock, 1.0) return h if mul == 1.0 else h * float(mul) class DSHF_OutMul(torch.nn.Module): def forward(self, h: torch.Tensor): if not DSHF.exp_enable or DSHF.currentTimestep < DSHF.exp_timestep: return h mul = _get_or_last(DSHF.exp_out_muls, DSHF.currentBlock, 1.0) return h if mul == 1.0 else h * float(mul) # ---------------- Подменённый U-Net ---------------- class DeepShrinkHiresFixUNet(sd_unet.SdUnet): def __init__(self, _model): super().__init__() self.model = _model.to(devices.device) getter = lambda: (DSHF.interp_method, DSHF.interp_antialias, DSHF.min_feature_size) # Оборачивание слоёв for i, input_block in enumerate(self.model.input_blocks): for j, layer in enumerate(input_block): if isinstance(layer, ResBlock): for k, in_layer in enumerate(layer.in_layers): if isinstance(in_layer, torch.nn.Conv2d): self.model.input_blocks[i][j].in_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([in_layer], getter), in_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_InMul() ) for k, out_layer in enumerate(layer.out_layers): if isinstance(out_layer, torch.nn.Conv2d): self.model.input_blocks[i][j].out_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([out_layer], getter), out_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_OutMul() ) else: if isinstance(layer, torch.nn.Conv2d): self.model.input_blocks[i][j] = torch.nn.Sequential( DSHF.DSHF_Scale([layer], getter), layer, DSHF.DSHF_Unscale(getter) ) if isinstance(layer, Downsample): layer.op = torch.nn.Sequential(DSHF.DSHF_Scale([layer.op], getter), layer.op, DSHF.DSHF_Unscale(getter)) if isinstance(layer, Upsample) and hasattr(layer, "conv") and isinstance(layer.conv, torch.nn.Conv2d): layer.conv = torch.nn.Sequential(DSHF.DSHF_Scale([layer.conv], getter), layer.conv, DSHF.DSHF_Unscale(getter)) for j, layer in enumerate(self.model.middle_block): if isinstance(layer, ResBlock): for k, in_layer in enumerate(layer.in_layers): if isinstance(in_layer, torch.nn.Conv2d): self.model.middle_block[j].in_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([in_layer], getter), in_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_InMul() ) for k, out_layer in enumerate(layer.out_layers): if isinstance(out_layer, torch.nn.Conv2d): self.model.middle_block[j].out_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([out_layer], getter), out_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_OutMul() ) else: if isinstance(layer, torch.nn.Conv2d): self.model.middle_block[j] = torch.nn.Sequential( DSHF.DSHF_Scale([layer], getter), layer, DSHF.DSHF_Unscale(getter) ) if isinstance(layer, Downsample): layer.op = torch.nn.Sequential(DSHF.DSHF_Scale([layer.op], getter), layer.op, DSHF.DSHF_Unscale(getter)) if isinstance(layer, Upsample) and hasattr(layer, "conv") and isinstance(layer.conv, torch.nn.Conv2d): layer.conv = torch.nn.Sequential(DSHF.DSHF_Scale([layer.conv], getter), layer.conv, DSHF.DSHF_Unscale(getter)) for i, output_block in enumerate(self.model.output_blocks): for j, layer in enumerate(output_block): if isinstance(layer, ResBlock): for k, in_layer in enumerate(layer.in_layers): if isinstance(in_layer, torch.nn.Conv2d): self.model.output_blocks[i][j].in_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([in_layer], getter), in_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_InMul() ) for k, out_layer in enumerate(layer.out_layers): if isinstance(out_layer, torch.nn.Conv2d): self.model.output_blocks[i][j].out_layers[k] = torch.nn.Sequential( DSHF.DSHF_Scale([out_layer], getter), out_layer, DSHF.DSHF_Unscale(getter), DSHF.DSHF_OutMul() ) else: if isinstance(layer, torch.nn.Conv2d): self.model.output_blocks[i][j] = torch.nn.Sequential( DSHF.DSHF_Scale([layer], getter), layer, DSHF.DSHF_Unscale(getter) ) if isinstance(layer, Downsample): layer.op = torch.nn.Sequential(DSHF.DSHF_Scale([layer.op], getter), layer.op, DSHF.DSHF_Unscale(getter)) if isinstance(layer, Upsample) and hasattr(layer, "conv") and isinstance(layer.conv, torch.nn.Conv2d): layer.conv = torch.nn.Sequential(DSHF.DSHF_Scale([layer.conv], getter), layer.conv, DSHF.DSHF_Unscale(getter)) for i, module in enumerate(self.model.out): if isinstance(module, torch.nn.Conv2d): self.model.out[i] = torch.nn.Sequential(DSHF.DSHF_Scale([module], getter), module, DSHF.DSHF_Unscale(getter)) def forward(self, x, timesteps, context, y=None, **kwargs): assert (y is not None) == (self.model.num_classes is not None), "must specify y iff class-conditional" if DSHF.channels_last: x = x.contiguous(memory_format=torch.channels_last) hs = [] emb = self.model.time_embed(timestep_embedding(timesteps, self.model.model_channels, repeat_only=False)) if self.model.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.model.label_emb(y) h = x.type(self.model.dtype) depth = 0 DSHF.currentBlock = 0 DSHF.currentConv = 0 DSHF.currentTimestep = float(timesteps.detach().float().min().item()) # Входные блоки for module in self.model.input_blocks: s = DSHF._block_scale(depth, DSHF.currentTimestep) if s is not None: h = _resize(h, 1.0 / float(s), DSHF.interp_method, DSHF.interp_antialias, DSHF.min_feature_size) context_tmp = context if DSHF.exp_enable and DSHF.currentTimestep >= DSHF.exp_timestep: cfg_mul = _get_or_last(DSHF.exp_cfg_muls, DSHF.currentBlock, 1.0) if cfg_mul != 1.0: context_tmp = context * float(cfg_mul) h = module(h, emb, context_tmp) hs.append(h) depth += 1 DSHF.currentBlock += 1 # Средний блок s = DSHF._block_scale(depth, DSHF.currentTimestep) if s is not None: h = _resize(h, 1.0 / float(s), DSHF.interp_method, DSHF.interp_antialias, DSHF.min_feature_size) context_tmp = context if DSHF.exp_enable and DSHF.currentTimestep >= DSHF.exp_timestep: cfg_mul = _get_or_last(DSHF.exp_cfg_muls, DSHF.currentBlock, 1.0) if cfg_mul != 1.0: context_tmp = context * float(cfg_mul) h = self.model.middle_block(h, emb, context_tmp) s = DSHF._block_scale(depth, DSHF.currentTimestep) if s is not None: h = _resize(h, float(s), DSHF.interp_method, DSHF.interp_antialias, DSHF.min_feature_size) DSHF.currentBlock += 1 # Выходные блоки for module in self.model.output_blocks: depth -= 1 h = torch.cat([h, hs.pop()], dim=1) context_tmp = context if DSHF.exp_enable and DSHF.currentTimestep >= DSHF.exp_timestep: cfg_mul = _get_or_last(DSHF.exp_cfg_muls, DSHF.currentBlock, 1.0) if cfg_mul != 1.0: context_tmp = context * float(cfg_mul) h = module(h, emb, context_tmp) s = DSHF._block_scale(depth, DSHF.currentTimestep) if s is not None: h = _resize(h, float(s), DSHF.interp_method, DSHF.interp_antialias, DSHF.min_feature_size) DSHF.currentBlock += 1 h = h.type(x.dtype) return self.model.id_predictor(h) if self.model.predict_codebook_ids else self.model.out(h) # Регистрация варианта U-Net DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption() DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix" DeepShrinkHiresFixUNetOption.create_unet = lambda: DSHF.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model) script_callbacks.on_list_unets(lambda unets: unets.append(DeepShrinkHiresFixUNetOption))