dikdimon commited on
Commit
030cb3d
·
verified ·
1 Parent(s): d39654b

Update 3-bmab/sd_bmab/sd_override/txt2img.py

Browse files
Files changed (1) hide show
  1. 3-bmab/sd_bmab/sd_override/txt2img.py +86 -29
3-bmab/sd_bmab/sd_override/txt2img.py CHANGED
@@ -144,38 +144,94 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
144
  return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
145
 
146
  def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
147
- is_sdxl = getattr(self.sd_model, 'is_sdxl', False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if is_sdxl:
150
- min_scale = max(0.5, self.progressive_growing_min_scale)
151
- else:
152
- min_scale = self.progressive_growing_min_scale
153
-
154
- resolution_steps = np.linspace(min_scale, self.progressive_growing_max_scale, self.progressive_growing_steps)
155
-
156
- initial_width = max(512 if is_sdxl else 64, int(self.width * resolution_steps[0]))
157
- initial_height = max(512 if is_sdxl else 64, int(self.height * resolution_steps[0]))
158
-
159
- x = create_random_tensors((opt_C, initial_height // opt_f, initial_width // opt_f), seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
160
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
161
-
162
- for i in range(1, len(resolution_steps)):
163
- target_width = int(self.width * resolution_steps[i])
164
- target_height = int(self.height * resolution_steps[i])
165
-
166
- if is_sdxl:
167
- target_width = max(512, min(1536, target_width))
168
- target_height = max(512, min(1536, target_height))
169
 
170
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode='bicubic', align_corners=False)
 
 
 
 
 
 
 
 
171
 
172
- if self.progressive_growing_refinement:
173
- steps_for_refinement = self.steps // len(resolution_steps)
174
- noise = create_random_tensors(samples.shape[1:], seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
175
- decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
176
- decoded_samples = torch.stack(decoded_samples).float()
177
- decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
178
- self.image_conditioning = self.img2img_image_conditioning(decoded_samples * 2 - 1, samples)
179
 
180
  samples = self.sampler.sample_img2img(
181
  self,
@@ -188,6 +244,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
188
  )
189
 
190
  return samples
 
191
 
192
  def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
193
  if shared.state.interrupted:
 
144
  return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
145
 
146
  def sample_progressive(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
147
+ import numpy as np
148
+ import torch
149
+
150
+ # SDXL больше НЕ клампим: используем ровно то, что пришло из UI
151
+ min_scale = float(self.progressive_growing_min_scale)
152
+ max_scale = float(self.progressive_growing_max_scale)
153
+ steps_cnt = int(self.progressive_growing_steps)
154
+
155
+ # Если хочешь позволять "shrink", оставь как есть.
156
+ # Если нужен только рост, раскомментируй:
157
+ # if min_scale > max_scale:
158
+ # min_scale, max_scale = max_scale, min_scale
159
+
160
+ # Равномерные масштабы без SDXL-клампов
161
+ resolution_steps = np.linspace(min_scale, max_scale, steps_cnt)
162
+
163
+ # Вспомогательно: привести к кратности opt_f и не дать упасть до 0
164
+ def _snap(v: float) -> int:
165
+ from modules.processing import opt_f
166
+ x = int(v)
167
+ x = max(opt_f, x)
168
+ x = (x // opt_f) * opt_f
169
+ return max(opt_f, x)
170
+
171
+ # Стартовое разрешение (никаких 512/1536 клампов)
172
+ initial_width = _snap(self.width * resolution_steps[0])
173
+ initial_height = _snap(self.height * resolution_steps[0])
174
+
175
+ from modules.processing import opt_C, create_random_tensors, decode_latent_batch
176
+ from modules import devices
177
+
178
+ x = create_random_tensors(
179
+ (opt_C, initial_height // opt_f, initial_width // opt_f),
180
+ seeds,
181
+ subseeds=subseeds,
182
+ subseed_strength=subseed_strength,
183
+ seed_resize_from_h=self.seed_resize_from_h,
184
+ seed_resize_from_w=self.seed_resize_from_w,
185
+ p=self
186
+ )
187
+
188
+ samples = self.sampler.sample(
189
+ self,
190
+ x,
191
+ conditioning,
192
+ unconditional_conditioning,
193
+ image_conditioning=self.txt2img_image_conditioning(x)
194
+ )
195
+
196
+ total_stages = len(resolution_steps)
197
+
198
+ for i in range(1, total_stages):
199
+ target_width = _snap(self.width * resolution_steps[i])
200
+ target_height = _snap(self.height * resolution_steps[i])
201
+
202
+ # Ресэмпл латентов до следующего шага без SDXL-клампов
203
+ samples = torch.nn.functional.interpolate(
204
+ samples,
205
+ size=(target_height // opt_f, target_width // opt_f),
206
+ mode='bicubic',
207
+ align_corners=False
208
+ )
209
 
210
+ if self.progressive_growing_refinement:
211
+ # хотя бы 1 шаг на стадию
212
+ steps_for_refinement = max(1, self.steps // total_stages)
213
+
214
+ noise = create_random_tensors(
215
+ samples.shape[1:], # (C, H/8, W/8)
216
+ seeds,
217
+ subseeds=subseeds,
218
+ subseed_strength=subseed_strength,
219
+ seed_resize_from_h=self.seed_resize_from_h,
220
+ seed_resize_from_w=self.seed_resize_from_w,
221
+ p=self
222
+ )
 
 
 
 
 
 
 
223
 
224
+ decoded = decode_latent_batch(
225
+ self.sd_model,
226
+ samples,
227
+ target_device=devices.cpu,
228
+ check_for_nans=True
229
+ )
230
+ decoded = torch.stack(decoded).float()
231
+ decoded = torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0) # [0..1]
232
+ src = decoded * 2.0 - 1.0 # [-1..1]
233
 
234
+ self.image_conditioning = self.img2img_image_conditioning(src, samples)
 
 
 
 
 
 
235
 
236
  samples = self.sampler.sample_img2img(
237
  self,
 
244
  )
245
 
246
  return samples
247
+
248
 
249
  def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
250
  if shared.state.interrupted: