lschmidt committed on
Commit
b821298
·
verified ·
1 Parent(s): 92f0a1a

Upload cond_ddim_pipeline.py

Browse files
Files changed (1) hide show
  1. cond_ddim_pipeline.py +123 -0
cond_ddim_pipeline.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ import inspect
4
+ from diffusers import DDIMScheduler, DiffusionPipeline, ImagePipelineOutput
5
+
6
class CondDDIMPipeline(DiffusionPipeline):
    r"""
    Pipeline for conditional image generation with DDIM sampling.

    The UNet receives the noisy latents concatenated channel-wise with a
    conditioning (e.g. low-resolution) image, so ``unet.config.in_channels``
    must equal ``unet.config.out_channels`` plus the conditioning image's
    channel count.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()

        # Always sample with DDIM regardless of the scheduler the model was
        # saved with; the scheduler configs are interchangeable.
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        image: torch.Tensor = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_images_per_cond: Optional[int] = 1,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Unused; the batch size is inferred from `image`. Kept for API compatibility.
            image (torch.Tensor):
                The LR image(s) to condition on, shaped `(batch, channels, height, width)`. Required.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic. If `None`, a fresh generator on the execution device is created.
            num_images_per_cond (`int`, *optional*, defaults to 1):
                Number of samples to draw per conditioning image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
                DDIM and `1` corresponds to DDPM.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
                downstream to the scheduler (use `None` for schedulers which don't support this argument).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                The generated images, or `(images,)` when `return_dict=False`.
        """
        if image is None:
            raise ValueError("`image` (the conditioning tensor) must be provided.")

        bs, _, height, width = image.shape

        # Respect a caller-supplied generator; only create one when absent.
        # (Previously the argument was unconditionally overwritten, which made
        # the documented determinism guarantee a no-op.)
        if generator is None:
            generator = torch.Generator(device=self._execution_device)

        latents_shape = (bs * num_images_per_cond, self.unet.config.out_channels, height, width)
        latents = torch.randn(latents_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
        latents_dtype = next(self.unet.parameters()).dtype

        # Replicate the conditioning image for each requested sample and move
        # it to the execution device in the UNet's dtype.
        image = torch.cat([image] * num_images_per_cond)
        image = image.to(device=self._execution_device, dtype=latents_dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502), should be in [0, 1],
        # and is only forwarded when the scheduler's `step` accepts it.
        # `use_clipped_model_output` is only forwarded when explicitly set, per the docstring above.
        extra_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            extra_kwargs["eta"] = eta
        if use_clipped_model_output is not None:
            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict the noise residual from the latents concatenated with
            #    the conditioning image along the channel dimension
            latents_input = torch.cat([latents, image], dim=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)

            noise_pred = self.unet(latents_input, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # do x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, generator=generator, **extra_kwargs
            ).prev_sample

        image = latents.cpu()
        if output_type == "pil":
            # numpy_to_pil expects channels-last float arrays in [0, 1]; map
            # the model output from [-1, 1] and permute before conversion.
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.permute(0, 2, 3, 1).numpy()
            image = self.numpy_to_pil(image)
        else:
            # Raw (batch, channels, height, width) numpy output, unchanged
            # from the previous behavior for non-PIL callers.
            image = image.numpy()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)