dikdimon commited on
Commit
66fad01
·
verified ·
1 Parent(s): f6f7182

Upload sd-webui-progressive-growing using SD-Hub

Browse files
sd-webui-progressive-growing/scripts/progressive_growing_always.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """sd-webui-progressive-growing
2
+
3
+ Always-visible UI + runtime patch for AUTOMATIC1111.
4
+
5
+ This extension ports the user's provided implementation of `sample_progressive()` 1:1.
6
+ It does NOT modify core files on disk; instead it monkey-patches
7
+ `modules.processing.StableDiffusionProcessingTxt2Img.sample` at runtime.
8
+
9
+ UI is AlwaysVisible (not in the Scripts dropdown).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import gradio as gr
15
+
16
+ from modules import scripts, sd_samplers, devices
17
+ from modules import processing as processing_mod
18
+ from modules.processing import create_random_tensors, decode_latent_batch, opt_C, opt_f
19
+
20
+
21
+ # -----------------------------
22
+ # Progressive Growing versions
23
+ # -----------------------------
24
+
25
+ def sample_progressive_v1_exact(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
26
+ """Exact copy of the user-provided implementation (processing.py::sample_progressive)."""
27
+
28
+ import numpy as np
29
+ import torch
30
+
31
+ is_sdxl = getattr(self.sd_model, 'is_sdxl', False)
32
+
33
+ # 1) Больше НЕТ принудительного min_scale>=0.5 для SDXL:
34
+ min_scale = float(self.progressive_growing_min_scale)
35
+ max_scale = float(self.progressive_growing_max_scale)
36
+
37
+ # На всякий случай: если пользователь перепутал местами — делаем честный "рост"
38
+ # (если хочешь позволять "shrink", просто убери этот swap)
39
+ # if min_scale > max_scale:
40
+ # min_scale, max_scale = max_scale, min_scale
41
+
42
+ resolution_steps = np.linspace(min_scale, max_scale, int(self.progressive_growing_steps))
43
+
44
+ def _snap(v):
45
+ v_int = int(v)
46
+ v_int = max(opt_f, v_int)
47
+ v_int = (v_int // opt_f) * opt_f
48
+ return max(opt_f, v_int)
49
+
50
+ # 2) Стартовое разрешение
51
+ initial_width = _snap(self.width * resolution_steps[0])
52
+ initial_height = _snap(self.height * resolution_steps[0])
53
+
54
+ # 3) Начальный латент (noise)
55
+ x = create_random_tensors(
56
+ (opt_C, initial_height // opt_f, initial_width // opt_f),
57
+ seeds,
58
+ subseeds=subseeds,
59
+ subseed_strength=subseed_strength,
60
+ seed_resize_from_h=self.seed_resize_from_h,
61
+ seed_resize_from_w=self.seed_resize_from_w,
62
+ p=self
63
+ )
64
+
65
+ # 4) Первый проход sampler.sample()
66
+ samples = self.sampler.sample(
67
+ self,
68
+ x,
69
+ conditioning,
70
+ unconditional_conditioning,
71
+ image_conditioning=self.txt2img_image_conditioning(x)
72
+ )
73
+
74
+ total_stages = len(resolution_steps)
75
+
76
+ # 5) Прогрессивный рост
77
+ for i in range(1, total_stages):
78
+ target_width = _snap(self.width * resolution_steps[i])
79
+ target_height = _snap(self.height * resolution_steps[i])
80
+
81
+ # upscale latent
82
+ samples = torch.nn.functional.interpolate(
83
+ samples,
84
+ size=(target_height // opt_f, target_width // opt_f),
85
+ mode='bicubic',
86
+ align_corners=False
87
+ )
88
+
89
+ # 6) Refinement на каждом шаге (опционально)
90
+ if self.progressive_growing_refinement:
91
+ steps_for_refinement = max(1, self.steps // total_stages)
92
+
93
+ noise = create_random_tensors(
94
+ samples.shape[1:],
95
+ seeds,
96
+ subseeds=subseeds,
97
+ subseed_strength=subseed_strength,
98
+ seed_resize_from_h=self.seed_resize_from_h,
99
+ seed_resize_from_w=self.seed_resize_from_w,
100
+ p=self
101
+ )
102
+
103
+ decoded = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
104
+ decoded = torch.stack(decoded).float()
105
+ decoded = torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0)
106
+
107
+ source_img = decoded * 2.0 - 1.0
108
+ self.image_conditioning = self.img2img_image_conditioning(source_img, samples)
109
+ samples = self.sampler.sample_img2img(
110
+ self,
111
+ samples,
112
+ noise,
113
+ conditioning,
114
+ unconditional_conditioning,
115
+ steps=steps_for_refinement,
116
+ image_conditioning=self.image_conditioning
117
+ )
118
+
119
+ return samples
120
+
121
+
122
+ _VERSIONS = {
123
+ "v1 (exact)": sample_progressive_v1_exact,
124
+ }
125
+
126
+
127
+ # -----------------------------
128
+ # Runtime patching
129
+ # -----------------------------
130
+
131
+ _PATCHED = False
132
+ _ORIG_SAMPLE = None
133
+
134
+
135
+ def _apply_patch_once() -> None:
136
+ """Patch StableDiffusionProcessingTxt2Img.sample to route into sample_progressive_* when enabled."""
137
+
138
+ global _PATCHED, _ORIG_SAMPLE
139
+ if _PATCHED:
140
+ return
141
+
142
+ cls = getattr(processing_mod, 'StableDiffusionProcessingTxt2Img', None)
143
+ if cls is None:
144
+ return
145
+
146
+ # already patched by us (or another copy)
147
+ if getattr(cls, '_progressive_growing_ext_patched', False):
148
+ _PATCHED = True
149
+ return
150
+
151
+ _ORIG_SAMPLE = cls.sample
152
+
153
+ def _sample_wrapper(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
154
+ # Only intercept when user enabled the feature
155
+ if getattr(self, 'enable_progressive_growing', False):
156
+ # mirror the user code: sampler is created at the start of sample()
157
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
158
+
159
+ # pick version (defaults to exact v1)
160
+ ver = getattr(self, 'progressive_growing_version', 'v1 (exact)')
161
+ fn = _VERSIONS.get(ver, sample_progressive_v1_exact)
162
+ return fn(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)
163
+
164
+ # fallback to original behaviour (including its sampler creation)
165
+ return _ORIG_SAMPLE(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts)
166
+
167
+ cls.sample = _sample_wrapper
168
+ cls._progressive_growing_ext_patched = True
169
+ _PATCHED = True
170
+
171
+
172
+ # -----------------------------
173
+ # Always-visible UI script
174
+ # -----------------------------
175
+
176
+
177
+ class ProgressiveGrowingAlwaysVisible(scripts.Script):
178
+ def title(self):
179
+ return "Progressive Growing"
180
+
181
+ def show(self, is_img2img):
182
+ # Only for txt2img, always visible
183
+ return scripts.AlwaysVisible if not is_img2img else False
184
+
185
+ def ui(self, is_img2img):
186
+ with gr.Accordion("Progressive Growing", open=False):
187
+ enabled = gr.Checkbox(value=False, label="Enable")
188
+ version = gr.Dropdown(choices=list(_VERSIONS.keys()), value="v1 (exact)", label="Version")
189
+
190
+ min_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.25, label="Min scale")
191
+ max_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=1.0, label="Max scale")
192
+ steps = gr.Slider(minimum=2, maximum=16, step=1, value=4, label="Stages")
193
+ refinement = gr.Checkbox(value=True, label="Refinement between stages")
194
+
195
+ gr.Markdown(
196
+ "- Starts at Min scale, then increases latent resolution up to Max scale.\n"
197
+ "- Optional short img2img refinement at each stage.\n"
198
+ "- This implementation matches your provided code (v1 exact)."
199
+ )
200
+
201
+ return [enabled, version, min_scale, max_scale, steps, refinement]
202
+
203
+ def process(self, p, enabled, version, min_scale, max_scale, steps, refinement):
204
+ _apply_patch_once()
205
+
206
+ # store parameters on p (matching the reference implementation's attribute names)
207
+ p.enable_progressive_growing = bool(enabled)
208
+ p.progressive_growing_version = str(version)
209
+ p.progressive_growing_min_scale = float(min_scale)
210
+ p.progressive_growing_max_scale = float(max_scale)
211
+ p.progressive_growing_steps = int(steps)
212
+ p.progressive_growing_refinement = bool(refinement)
213
+
214
+ if p.enable_progressive_growing:
215
+ # Keep generation params so they show up in infotext (if UI/processing prints them)
216
+ try:
217
+ p.extra_generation_params["Progressive Growing"] = "True"
218
+ p.extra_generation_params["Min Scale"] = p.progressive_growing_min_scale
219
+ p.extra_generation_params["Max Scale"] = p.progressive_growing_max_scale
220
+ p.extra_generation_params["Progressive Growing Steps"] = p.progressive_growing_steps
221
+ p.extra_generation_params["Refinement"] = "True" if p.progressive_growing_refinement else None
222
+ p.extra_generation_params["PG Version"] = p.progressive_growing_version
223
+ except Exception:
224
+ # extra_generation_params may not exist in some contexts
225
+ pass