coralLight commited on
Commit
21110be
·
1 Parent(s): 1ee92d4

add inference

Browse files
Files changed (2) hide show
  1. app.py +20 -13
  2. customed_unipc_scheduler.py +17 -12
app.py CHANGED
@@ -198,20 +198,27 @@ def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guid
198
  latent_model_input = pipe.scheduler.scale_model_input(latent_model_input , timestep=t)
199
  negative_prompts = '(worst quality:2), (low quality:2), (normal quality:2), bad anatomy, bad proportions, poorly drawn face, poorly drawn hands, missing fingers, extra limbs, blurry, pixelated, distorted, lowres, jpeg artifacts, watermark, signature, text, (deformed:1.5), (bad hands:1.3), overexposed, underexposed, censored, mutated, extra fingers, cloned face, bad eyes'
200
  negative_prompts = 1 * [negative_prompts]
201
-
 
202
  prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
203
- , prompts
204
- , need_cfg=True
205
- , device=pipe.device
206
- , negative_prompt=negative_prompts
207
- , W=width
208
- , H=height)
209
- noise_pred = pipe.unet(latent_model_input
 
 
 
 
 
210
  , t
211
  , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)
212
  , added_cond_kwargs=cond_kwargs).sample
213
- uncond, cond = noise_pred.chunk(2)
214
- noise_pred = uncond + (cond - uncond) * guidance_scale
 
215
  latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
216
  idx += 1
217
 
@@ -325,13 +332,13 @@ with gr.Blocks() as demo:
325
  guidance_scale = gr.Slider(
326
  label="Guidance scale",
327
  minimum=0.0,
328
- maximum=10.0,
329
  step=0.1,
330
- value=7.5, # Replace with defaults that work for your model
331
  )
332
 
333
  num_inference_steps = gr.Dropdown(
334
- choices=[6, 7, 8],
335
  value=8,
336
  label="Number of inference steps",
337
  )
 
198
  latent_model_input = pipe.scheduler.scale_model_input(latent_model_input , timestep=t)
199
  negative_prompts = '(worst quality:2), (low quality:2), (normal quality:2), bad anatomy, bad proportions, poorly drawn face, poorly drawn hands, missing fingers, extra limbs, blurry, pixelated, distorted, lowres, jpeg artifacts, watermark, signature, text, (deformed:1.5), (bad hands:1.3), overexposed, underexposed, censored, mutated, extra fingers, cloned face, bad eyes'
200
  negative_prompts = 1 * [negative_prompts]
201
+ use_afs = num_inference_steps < 7
202
+ use_free_predictor = False
203
  prompt_embeds, cond_kwargs = prepare_sdxl_pipeline_step_parameter(pipe
204
+ , prompts
205
+ , need_cfg=True
206
+ , device=pipe.device
207
+ , negative_prompt=negative_prompts
208
+ , W=width
209
+ , H=height)
210
+ if idx == 0 and use_afs:
211
+ noise_pred = latent_model_input * 0.975
212
+ elif idx == len(pipe.scheduler.timesteps) - 1 and use_free_predictor:
213
+ noise_pred = None
214
+ else:
215
+ noise_pred = pipe.unet(latent_model_input
216
  , t
217
  , encoder_hidden_states=prompt_embeds.to(device=latents.device, dtype=latents.dtype)
218
  , added_cond_kwargs=cond_kwargs).sample
219
+ if noise_pred is not None:
220
+ uncond, cond = noise_pred.chunk(2)
221
+ noise_pred = uncond + (cond - uncond) * guidance_scale
222
  latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
223
  idx += 1
224
 
 
332
  guidance_scale = gr.Slider(
333
  label="Guidance scale",
334
  minimum=0.0,
335
+ maximum=6.0,
336
  step=0.1,
337
+ value=5.5, # Replace with defaults that work for your model
338
  )
339
 
340
  num_inference_steps = gr.Dropdown(
341
+ choices=[5, 6, 7, 8],
342
  value=8,
343
  label="Number of inference steps",
344
  )
customed_unipc_scheduler.py CHANGED
@@ -215,7 +215,8 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
215
  skip_type: str = "customed_time_karras",
216
  denoise_to_zero: bool = False,
217
  rescale_betas_zero_snr: bool = False,
218
- use_afs: bool = False
 
219
  ):
220
 
221
  if self.config.use_beta_sigmas and not is_scipy_available():
@@ -238,6 +239,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
238
  raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
239
 
240
  self.skip_type = skip_type
 
241
  self.use_afs = use_afs
242
  self.denoise_to_zero = denoise_to_zero
243
  if rescale_betas_zero_snr:
@@ -331,7 +333,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
331
  ct_end = self._sigma_to_t(sigmas[9], log_sigmas)
332
  if self.denoise_to_zero:
333
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
334
- timesteps = self.get_sigmas_karras(9, ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
335
  elif N == 5:
336
  log_sigmas = np.log(sigmas)
337
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
@@ -339,7 +341,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
339
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
340
  if self.denoise_to_zero:
341
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
342
- timesteps = self.get_sigmas_karras(5, ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
343
  elif N == 6:
344
  log_sigmas = np.log(sigmas)
345
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
@@ -347,7 +349,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
347
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
348
  if self.denoise_to_zero:
349
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
350
- timesteps = self.get_sigmas_karras(6, ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
351
  elif N == 7:
352
  log_sigmas = np.log(sigmas)
353
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
@@ -355,7 +357,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
355
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
356
  if self.denoise_to_zero:
357
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
358
- timesteps = self.get_sigmas_karras(7, ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
359
  elif N == 8:
360
  log_sigmas = np.log(sigmas).copy()
361
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
@@ -363,10 +365,10 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
363
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
364
  if self.denoise_to_zero:
365
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
366
- timesteps = self.get_sigmas_karras(8, ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
367
 
368
- if self.use_afs:
369
- np.insert(timesteps,1,(timesteps[0]+timesteps[1]) / 2)
370
 
371
 
372
  timesteps_tmp = copy.deepcopy(timesteps)
@@ -523,6 +525,9 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
523
  sigma = self.sigmas[self.step_index]
524
  alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
525
 
 
 
 
526
  if self.predict_x0:
527
  if self.config.prediction_type == "epsilon":
528
  x0_pred = (sample - sigma_t * model_output) / alpha_t
@@ -560,7 +565,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
560
 
561
  def multistep_uni_p_bh_update(
562
  self,
563
- model_output: torch.Tensor,
564
  *args,
565
  sample: torch.Tensor = None,
566
  order: int = None,
@@ -897,7 +902,7 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
897
  )
898
 
899
  model_output_convert = self.convert_model_output(model_output, sample=sample)
900
- if use_corrector:
901
  sample = self.multistep_uni_c_bh_update(
902
  this_model_output=model_output_convert,
903
  last_sample=self.last_sample,
@@ -908,8 +913,8 @@ class CustomedUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
908
  for i in range(self.solver_order - 1):
909
  self.model_outputs[i] = self.model_outputs[i + 1]
910
  self.timestep_list[i] = self.timestep_list[i + 1]
911
-
912
- self.model_outputs[-1] = model_output_convert
913
  self.timestep_list[-1] = timestep
914
 
915
  if self.config.lower_order_final:
 
215
  skip_type: str = "customed_time_karras",
216
  denoise_to_zero: bool = False,
217
  rescale_betas_zero_snr: bool = False,
218
+ use_afs: bool = False,
219
+ use_free_predictor = False
220
  ):
221
 
222
  if self.config.use_beta_sigmas and not is_scipy_available():
 
239
  raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
240
 
241
  self.skip_type = skip_type
242
+ self.use_free_predictor = use_free_predictor
243
  self.use_afs = use_afs
244
  self.denoise_to_zero = denoise_to_zero
245
  if rescale_betas_zero_snr:
 
333
  ct_end = self._sigma_to_t(sigmas[9], log_sigmas)
334
  if self.denoise_to_zero:
335
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
336
+ timesteps = self.get_sigmas_karras(9 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
337
  elif N == 5:
338
  log_sigmas = np.log(sigmas)
339
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
 
341
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
342
  if self.denoise_to_zero:
343
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
344
+ timesteps = self.get_sigmas_karras(5 + (1 if self.use_afs else 0) + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
345
  elif N == 6:
346
  log_sigmas = np.log(sigmas)
347
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
 
349
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
350
  if self.denoise_to_zero:
351
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
352
+ timesteps = self.get_sigmas_karras(6 + (1 if self.use_afs else 0) + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
353
  elif N == 7:
354
  log_sigmas = np.log(sigmas)
355
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
 
357
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
358
  if self.denoise_to_zero:
359
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
360
+ timesteps = self.get_sigmas_karras(7 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
361
  elif N == 8:
362
  log_sigmas = np.log(sigmas).copy()
363
  sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0)
 
365
  ct_end = self._sigma_to_t(sigmas[6], log_sigmas)
366
  if self.denoise_to_zero:
367
  ct_real_end = self._sigma_to_t(sigmas[-1], log_sigmas)
368
+ timesteps = self.get_sigmas_karras(8 + (1 if self.use_free_predictor else 0), ct_end, ct_start,rho=1.2, customed_final_sigma= ct_real_end if self.denoise_to_zero else None)
369
 
370
+ if self.use_afs and N > 6:
371
+ timesteps = np.insert(timesteps,1,(timesteps[0]+timesteps[1]) / 2)
372
 
373
 
374
  timesteps_tmp = copy.deepcopy(timesteps)
 
525
  sigma = self.sigmas[self.step_index]
526
  alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
527
 
528
+ if model_output is None:
529
+ return None
530
+
531
  if self.predict_x0:
532
  if self.config.prediction_type == "epsilon":
533
  x0_pred = (sample - sigma_t * model_output) / alpha_t
 
565
 
566
  def multistep_uni_p_bh_update(
567
  self,
568
+ model_output: torch.Tensor = None,
569
  *args,
570
  sample: torch.Tensor = None,
571
  order: int = None,
 
902
  )
903
 
904
  model_output_convert = self.convert_model_output(model_output, sample=sample)
905
+ if use_corrector and model_output_convert is not None:
906
  sample = self.multistep_uni_c_bh_update(
907
  this_model_output=model_output_convert,
908
  last_sample=self.last_sample,
 
913
  for i in range(self.solver_order - 1):
914
  self.model_outputs[i] = self.model_outputs[i + 1]
915
  self.timestep_list[i] = self.timestep_list[i + 1]
916
+ if model_output_convert is not None:
917
+ self.model_outputs[-1] = model_output_convert
918
  self.timestep_list[-1] = timestep
919
 
920
  if self.config.lower_order_final: