BiliSakura commited on
Commit
969a448
·
verified ·
1 Parent(s): 02a21e2

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. ADM-G-256/pipeline.py +15 -0
  2. ADM-G-512/pipeline.py +15 -0
ADM-G-256/pipeline.py CHANGED
@@ -108,6 +108,20 @@ class ADMPipeline(DiffusionPipeline):
108
  def _is_ddim_like(step_params: Set[str]) -> bool:
109
  return "eta" in step_params
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  @staticmethod
112
  def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
113
  if not torch.is_tensor(timestep):
@@ -242,6 +256,7 @@ class ADMPipeline(DiffusionPipeline):
242
  "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
243
  )
244
 
 
245
  latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
246
 
247
  image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
 
108
  def _is_ddim_like(step_params: Set[str]) -> bool:
109
  return "eta" in step_params
110
 
111
+ @staticmethod
112
+ def _prepare_model_output_for_scheduler(
113
+ model_output: torch.Tensor,
114
+ channels: int,
115
+ scheduler: KarrasDiffusionSchedulers,
116
+ ) -> torch.Tensor:
117
+ if model_output.shape[1] != 2 * channels:
118
+ return model_output
119
+ variance_type = getattr(scheduler.config, "variance_type", None)
120
+ if scheduler.__class__.__name__ == "DDPMScheduler" and variance_type in ("learned", "learned_range"):
121
+ return model_output
122
+ model_output, _ = torch.split(model_output, channels, dim=1)
123
+ return model_output
124
+
125
  @staticmethod
126
  def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
127
  if not torch.is_tensor(timestep):
 
256
  "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
257
  )
258
 
259
+ step_model_output = self._prepare_model_output_for_scheduler(step_model_output, channels, scheduler)
260
  latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
261
 
262
  image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
ADM-G-512/pipeline.py CHANGED
@@ -108,6 +108,20 @@ class ADMPipeline(DiffusionPipeline):
108
  def _is_ddim_like(step_params: Set[str]) -> bool:
109
  return "eta" in step_params
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  @staticmethod
112
  def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
113
  if not torch.is_tensor(timestep):
@@ -242,6 +256,7 @@ class ADMPipeline(DiffusionPipeline):
242
  "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
243
  )
244
 
 
245
  latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
246
 
247
  image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)
 
108
  def _is_ddim_like(step_params: Set[str]) -> bool:
109
  return "eta" in step_params
110
 
111
+ @staticmethod
112
+ def _prepare_model_output_for_scheduler(
113
+ model_output: torch.Tensor,
114
+ channels: int,
115
+ scheduler: KarrasDiffusionSchedulers,
116
+ ) -> torch.Tensor:
117
+ if model_output.shape[1] != 2 * channels:
118
+ return model_output
119
+ variance_type = getattr(scheduler.config, "variance_type", None)
120
+ if scheduler.__class__.__name__ == "DDPMScheduler" and variance_type in ("learned", "learned_range"):
121
+ return model_output
122
+ model_output, _ = torch.split(model_output, channels, dim=1)
123
+ return model_output
124
+
125
  @staticmethod
126
  def _expand_timestep(timestep, batch: int, device: torch.device) -> torch.Tensor:
127
  if not torch.is_tensor(timestep):
 
256
  "Use a DDPM/DDIM-compatible scheduler or disable classifier guidance."
257
  )
258
 
259
+ step_model_output = self._prepare_model_output_for_scheduler(step_model_output, channels, scheduler)
260
  latents = scheduler.step(step_model_output, timestep, latents, return_dict=True, **extra_step_kwargs).prev_sample
261
 
262
  image = latents if output_type == "latent" else (latents / 2 + 0.5).clamp(0, 1)