Upload cond_ddim_pipeline.py
Browse files- cond_ddim_pipeline.py +123 -0
cond_ddim_pipeline.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
import torch
|
| 3 |
+
import inspect
|
| 4 |
+
from diffusers import DDIMScheduler, DiffusionPipeline, ImagePipelineOutput
|
| 5 |
+
|
| 6 |
+
class CondDDIMPipeline(DiffusionPipeline):
    r"""
    Pipeline for conditional image generation: random latents are denoised while being
    channel-concatenated with a conditioning image (e.g. a low-resolution input) at every step.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents. Its input channel count must equal
            `unet.config.out_channels` plus the conditioning image's channels, since latents and
            condition are concatenated along the channel dimension.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`]. Its config is used to rebuild a [`DDIMScheduler`],
            so sampling always runs with DDIM regardless of the class passed in.
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()

        # Always sample with DDIM, reusing the noise schedule the model was trained with.
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        image: torch.Tensor = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_images_per_cond: Optional[int] = 1,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Unused; kept for interface compatibility. The effective batch size is
                `image.shape[0] * num_images_per_cond`.
            image (`torch.Tensor`):
                The conditioning (e.g. LR) image(s) of shape `(batch, channels, height, width)`. Required.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic. If `None`, a fresh generator is created on the execution device.
            num_images_per_cond (`int`, *optional*, defaults to 1):
                Number of samples to draw per conditioning image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
                DDIM and `1` corresponds to DDPM.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
                downstream to the scheduler (use `None` for schedulers which don't support this argument).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                The generated images; a plain `(images,)` tuple when `return_dict=False`.

        Raises:
            ValueError: If `image` is not provided.
        """
        if image is None:
            raise ValueError("`image` (the conditioning tensor) is required.")

        bs, _, height, width = image.shape

        # Respect a caller-supplied generator; only create one when none was given.
        # (Previously the argument was unconditionally overwritten, which made
        # seeded generation impossible.)
        if generator is None:
            generator = torch.Generator(device=self._execution_device)

        # Sample the initial latents on the execution device with the UNet's parameter dtype.
        latents_dtype = next(self.unet.parameters()).dtype
        latents_shape = (bs * num_images_per_cond, self.unet.config.out_channels, height, width)
        latents = torch.randn(
            latents_shape, generator=generator, device=self._execution_device, dtype=latents_dtype
        )

        # Repeat the conditioning image once per requested sample and move it
        # alongside the latents (same device/dtype so they can be concatenated).
        image = torch.cat([image] * num_images_per_cond)
        image = image.to(device=self._execution_device, dtype=latents_dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler (https://arxiv.org/abs/2010.02502, in [0, 1]);
        # pass it through only when step() accepts it, so other schedulers keep working.
        extra_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output from latents concatenated with the condition
            latents_input = torch.cat([latents, image], dim=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)

            noise_pred = self.unet(latents_input, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # do x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                use_clipped_model_output=use_clipped_model_output,
                generator=generator,
                **extra_kwargs,
            ).prev_sample

        # Map the model output to [0, 1] and channels-last (N, H, W, C), the layout
        # `numpy_to_pil` expects. NOTE(review): the /2 + 0.5 rescale assumes the UNet
        # was trained on images in [-1, 1] (the standard diffusers convention) —
        # confirm against the training code.
        image = (latents / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
|