Luffuly committed on
Commit
52918a8
·
1 Parent(s): 0947f8b

add pipeline

Browse files
Files changed (1) hide show
  1. pipeline_img2mvimg.py +290 -0
pipeline_img2mvimg.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers import AutoencoderKL, StableDiffusionImageVariationPipeline, UNet2DConditionModel
7
+ from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
8
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
9
+ from PIL import Image
10
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11
+ from omegaconf import DictConfig, OmegaConf
12
+
13
+
14
+
15
class StableDiffusionImage2MVCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    """Image-to-multi-view variant of the Stable Diffusion image-variation pipeline.

    Extends ``StableDiffusionImageVariationPipeline`` with:
    - an optional per-channel latent offset applied to VAE latents
      (subtracted from the conditioning latents, added back to the
      denoised latents before decoding), and
    - class-label conditioning of the UNet (see ``__call__``), presumably
      one label per target view — TODO confirm against the trained UNet.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset: Optional[Tuple[float, ...]] = None,
        noisy_cond_latents: bool = False,
        condition_offset: bool = True,
    ):
        """Initialize the pipeline.

        Args:
            vae / image_encoder / unet / scheduler / safety_checker /
            feature_extractor / requires_safety_checker:
                Forwarded unchanged to the ``StableDiffusionImageVariationPipeline``
                constructor.
            latents_offset: Optional per-channel offset for the VAE latent
                space; stored as a tuple (so it is JSON-serializable in the
                model config) and registered to the config only when provided.
            noisy_cond_latents: Unsupported; passing ``True`` raises
                ``NotImplementedError``.
            condition_offset: When ``True``, ``encode_latents`` subtracts
                ``latents_offset`` from the conditioning latents.
        """
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker
        )
        # Normalize to a tuple so the value round-trips through the saved config.
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        # NOTE(review): only registered when not None, so a pipeline saved
        # without an offset has no `latents_offset` key in its config.
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        if noisy_cond_latents:
            raise NotImplementedError("Noisy condition latents not supported Now.")
        self.condition_offset = condition_offset
        self.register_to_config(condition_offset=condition_offset)
48
+
49
+ def encode_latents(self, image: Image.Image, device, dtype, height, width):
50
+ images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
51
+ # NOTE: .mode() for condition
52
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
53
+ if self.latents_offset is not None and self.condition_offset:
54
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
55
+ else:
56
+ return latents
57
+
58
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
59
+ dtype = next(self.image_encoder.parameters()).dtype
60
+
61
+ if not isinstance(image, torch.Tensor):
62
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
63
+
64
+ image = image.to(device=device, dtype=dtype)
65
+ image_embeddings = self.image_encoder(image).image_embeds
66
+ image_embeddings = image_embeddings.unsqueeze(1)
67
+
68
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
69
+ bs_embed, seq_len, _ = image_embeddings.shape
70
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
71
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
72
+
73
+ if do_classifier_free_guidance:
74
+ # NOTE: the same as original code
75
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
76
+ # For classifier free guidance, we need to do two forward passes.
77
+ # Here we concatenate the unconditional and text embeddings into a single batch
78
+ # to avoid doing two forward passes
79
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
80
+
81
+ return image_embeddings
82
+
83
    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        strength: float = 0.0,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
                Image to guide the multi-view generation. Only a single image is
                supported: a one-element list is unwrapped, anything longer raises
                `NotImplementedError`. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image. Falls back to
                `self.unet.config.sample_size * self.vae_scale_factor` when falsy.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image. Falls back like `height`.
            height_cond (`int`, *optional*, defaults to 512):
                Height the conditioning image is resized to before VAE encoding.
            width_cond (`int`, *optional*, defaults to 512):
                Width the conditioning image is resized to before VAE encoding.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the
                conditioning image at the expense of lower image quality. Guidance is enabled when `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`;
                `"latent"` skips VAE decoding and the safety checker.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            strength (`float`, *optional*, defaults to 0.0):
                Fraction of the schedule to skip: denoising starts at
                `timesteps[int(len(timesteps) * strength)]`.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.

        Examples:

        ```py
        from diffusers import StableDiffusionImageVariationPipeline
        from PIL import Image
        from io import BytesIO
        import requests

        pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
        )
        pipe = pipe.to("cuda")

        url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"

        response = requests.get(url)
        image = Image.open(BytesIO(response.content)).convert("RGB")

        out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
        out["images"][0].save("result.jpg")
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters — only a single conditioning image is supported.
        if isinstance(image, Image.Image):
            batch_size = 1
        elif len(image) == 1:
            # Unwrap a one-element list to the bare image.
            image = image[0]
            batch_size = 1
        else:
            raise NotImplementedError()
        # elif isinstance(image, list):
        #     batch_size = len(image)
        # else:
        #     batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image: CLIP embeddings for cross-attention, VAE latents
        # for channel-wise concatenation with the noisy latents.
        emb_image = image

        image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
        cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
        # Zero latents serve as the unconditional branch for CFG.
        cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
        # NOTE(review): `image_pixels` is computed (and CFG-doubled) but never
        # used below — looks like leftover code; confirm before removing.
        image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
        if do_classifier_free_guidance:
            image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables. `out_channels` is used as the latent
        # channel count because the UNet's *input* has twice that many channels
        # (noisy latents concatenated with conditioning latents below).
        num_channels_latents = self.unet.config.out_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )


        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 7. Denoising loop
        # NOTE(review): with strength > 0 fewer iterations run than the progress
        # bar total and the warmup bookkeeping assume — cosmetic only.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:

            # Per-sample class labels fed to the UNet; the 0..3 pattern repeated
            # twice presumably indexes 4 views across the uncond/cond CFG halves
            # — TODO confirm against the UNet's class-conditioning setup.
            class_labels = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3]).to(latents.device)

            # `strength` skips the first part of the schedule.
            for i, t in enumerate(timesteps[int(len(timesteps)*strength):]):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Tile the conditioning latents to match the model-input batch,
                # interleaved so each CFG half keeps its own conditioning.
                num_repeat = len(latent_model_input) // len(cond_latents) # [2, 4, 32, 32]
                cat_latents = torch.stack([cond_latents] * num_repeat, 1).reshape(*latent_model_input.shape)

                # Channel-wise concat: noisy latents + conditioning latents.
                sample = torch.cat([latent_model_input, cat_latents], dim=1)
                noise_pred = self.unet(
                    sample,
                    t,
                    image_embeddings,
                    cross_attention_kwargs=None,
                    class_labels=class_labels,
                ).sample
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        self.maybe_free_model_hooks()

        # Undo the conditioning offset before decoding (inverse of encode_latents).
        if self.latents_offset is not None:
            latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Denormalize every image unless the safety checker flagged it.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # NOTE(review): second call — already invoked right after the denoising
        # loop; harmless but redundant.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
290
+