Luffuly commited on
Commit
6c1281c
·
1 Parent(s): 5189c0a

rm old pipline

Browse files
Files changed (1) hide show
  1. pipeline_img2mvimg.py +0 -290
pipeline_img2mvimg.py DELETED
@@ -1,290 +0,0 @@
1
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
-
3
- import numpy as np
4
- import torch
5
-
6
- from diffusers import AutoencoderKL, StableDiffusionImageVariationPipeline, UNet2DConditionModel
7
- from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
8
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
9
- from PIL import Image
10
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11
- from omegaconf import DictConfig, OmegaConf
12
-
13
-
14
-
15
- class StableDiffusionImage2MVCustomPipeline(
16
- StableDiffusionImageVariationPipeline
17
- ):
18
- def __init__(
19
- self,
20
- vae: AutoencoderKL,
21
- image_encoder: CLIPVisionModelWithProjection,
22
- unet: UNet2DConditionModel,
23
- scheduler: KarrasDiffusionSchedulers,
24
- safety_checker: StableDiffusionSafetyChecker,
25
- feature_extractor: CLIPImageProcessor,
26
- requires_safety_checker: bool = True,
27
- latents_offset=None,
28
- noisy_cond_latents=False,
29
- condition_offset=True,
30
- ):
31
- super().__init__(
32
- vae=vae,
33
- image_encoder=image_encoder,
34
- unet=unet,
35
- scheduler=scheduler,
36
- safety_checker=safety_checker,
37
- feature_extractor=feature_extractor,
38
- requires_safety_checker=requires_safety_checker
39
- )
40
- latents_offset = tuple(latents_offset) if latents_offset is not None else None
41
- self.latents_offset = latents_offset
42
- if latents_offset is not None:
43
- self.register_to_config(latents_offset=latents_offset)
44
- if noisy_cond_latents:
45
- raise NotImplementedError("Noisy condition latents not supported Now.")
46
- self.condition_offset = condition_offset
47
- self.register_to_config(condition_offset=condition_offset)
48
-
49
- def encode_latents(self, image: Image.Image, device, dtype, height, width):
50
- images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
51
- # NOTE: .mode() for condition
52
- latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
53
- if self.latents_offset is not None and self.condition_offset:
54
- return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
55
- else:
56
- return latents
57
-
58
- def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
59
- dtype = next(self.image_encoder.parameters()).dtype
60
-
61
- if not isinstance(image, torch.Tensor):
62
- image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
63
-
64
- image = image.to(device=device, dtype=dtype)
65
- image_embeddings = self.image_encoder(image).image_embeds
66
- image_embeddings = image_embeddings.unsqueeze(1)
67
-
68
- # duplicate image embeddings for each generation per prompt, using mps friendly method
69
- bs_embed, seq_len, _ = image_embeddings.shape
70
- image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
71
- image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
72
-
73
- if do_classifier_free_guidance:
74
- # NOTE: the same as original code
75
- negative_prompt_embeds = torch.zeros_like(image_embeddings)
76
- # For classifier free guidance, we need to do two forward passes.
77
- # Here we concatenate the unconditional and text embeddings into a single batch
78
- # to avoid doing two forward passes
79
- image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
80
-
81
- return image_embeddings
82
-
83
- @torch.no_grad()
84
- def __call__(
85
- self,
86
- image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
87
- height: Optional[int] = 1024,
88
- width: Optional[int] = 1024,
89
- height_cond: Optional[int] = 512,
90
- width_cond: Optional[int] = 512,
91
- num_inference_steps: int = 50,
92
- guidance_scale: float = 7.5,
93
- num_images_per_prompt: Optional[int] = 1,
94
- eta: float = 0.0,
95
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
96
- latents: Optional[torch.FloatTensor] = None,
97
- output_type: Optional[str] = "pil",
98
- return_dict: bool = True,
99
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
100
- callback_steps: int = 1,
101
- strength: float = 0.0,
102
- ):
103
- r"""
104
- The call function to the pipeline for generation.
105
-
106
- Args:
107
- image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
108
- Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
109
- [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
110
- height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
111
- The height in pixels of the generated image.
112
- width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
113
- The width in pixels of the generated image.
114
- num_inference_steps (`int`, *optional*, defaults to 50):
115
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
116
- expense of slower inference. This parameter is modulated by `strength`.
117
- guidance_scale (`float`, *optional*, defaults to 7.5):
118
- A higher guidance scale value encourages the model to generate images closely linked to the text
119
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
120
- num_images_per_prompt (`int`, *optional*, defaults to 1):
121
- The number of images to generate per prompt.
122
- eta (`float`, *optional*, defaults to 0.0):
123
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
124
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
125
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
126
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
127
- generation deterministic.
128
- latents (`torch.FloatTensor`, *optional*):
129
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
130
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
131
- tensor is generated by sampling using the supplied random `generator`.config.j
132
- output_type (`str`, *optional*, defaults to `"pil"`):
133
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
134
- return_dict (`bool`, *optional*, defaults to `True`):
135
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
136
- plain tuple.
137
- callback (`Callable`, *optional*):
138
- A function that calls every `callback_steps` steps during inference. The function is called with the
139
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
140
- callback_steps (`int`, *optional*, defaults to 1):
141
- The frequency at which the `callback` function is called. If not specified, the callback is called at
142
- every step.
143
-
144
- Returns:
145
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
146
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
147
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
148
- second element is a list of `bool`s indicating whether the corresponding generated image contains
149
- "not-safe-for-work" (nsfw) content.
150
-
151
- Examples:
152
-
153
- ```py
154
- from diffusers import StableDiffusionImageVariationPipeline
155
- from PIL import Image
156
- from io import BytesIO
157
- import requests
158
-
159
- pipe = StableDiffusionImageVariationPipeline.from_pretrained(
160
- "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
161
- )
162
- pipe = pipe.to("cuda")
163
-
164
- url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
165
-
166
- response = requests.get(url)
167
- image = Image.open(BytesIO(response.content)).convert("RGB")
168
-
169
- out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
170
- out["images"][0].save("result.jpg")
171
- ```
172
- """
173
- # 0. Default height and width to unet
174
- height = height or self.unet.config.sample_size * self.vae_scale_factor
175
- width = width or self.unet.config.sample_size * self.vae_scale_factor
176
-
177
- # 1. Check inputs. Raise error if not correct
178
- self.check_inputs(image, height, width, callback_steps)
179
-
180
- # 2. Define call parameters
181
- if isinstance(image, Image.Image):
182
- batch_size = 1
183
- elif len(image) == 1:
184
- image = image[0]
185
- batch_size = 1
186
- else:
187
- raise NotImplementedError()
188
- # elif isinstance(image, list):
189
- # batch_size = len(image)
190
- # else:
191
- # batch_size = image.shape[0]
192
- device = self._execution_device
193
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
194
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
195
- # corresponds to doing no classifier free guidance.
196
- do_classifier_free_guidance = guidance_scale > 1.0
197
-
198
- # 3. Encode input image
199
- emb_image = image
200
-
201
- image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
202
- cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
203
- cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
204
- image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
205
- if do_classifier_free_guidance:
206
- image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
207
-
208
- # 4. Prepare timesteps
209
- self.scheduler.set_timesteps(num_inference_steps, device=device)
210
- timesteps = self.scheduler.timesteps
211
-
212
- # 5. Prepare latent variables
213
- num_channels_latents = self.unet.config.out_channels
214
- latents = self.prepare_latents(
215
- batch_size * num_images_per_prompt,
216
- num_channels_latents,
217
- height,
218
- width,
219
- image_embeddings.dtype,
220
- device,
221
- generator,
222
- latents,
223
- )
224
-
225
-
226
- # 6. Prepare extra step kwargs.
227
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
228
- # 7. Denoising loop
229
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
230
- with self.progress_bar(total=num_inference_steps) as progress_bar:
231
-
232
- class_labels = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3]).to(latents.device)
233
-
234
- for i, t in enumerate(timesteps[int(len(timesteps)*strength):]):
235
- # expand the latents if we are doing classifier free guidance
236
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
237
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
238
-
239
- num_repeat = len(latent_model_input) // len(cond_latents) # [2, 4, 32, 32]
240
- cat_latents = torch.stack([cond_latents] * num_repeat, 1).reshape(*latent_model_input.shape)
241
-
242
- sample = torch.cat([latent_model_input, cat_latents], dim=1)
243
- noise_pred = self.unet(
244
- sample,
245
- t,
246
- image_embeddings,
247
- cross_attention_kwargs=None,
248
- class_labels=class_labels,
249
- ).sample
250
- # perform guidance
251
- if do_classifier_free_guidance:
252
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
253
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
254
-
255
- # compute the previous noisy sample x_t -> x_t-1
256
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
257
-
258
- # call the callback, if provided
259
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
260
- progress_bar.update()
261
- if callback is not None and i % callback_steps == 0:
262
- step_idx = i // getattr(self.scheduler, "order", 1)
263
- callback(step_idx, t, latents)
264
-
265
- self.maybe_free_model_hooks()
266
-
267
- if self.latents_offset is not None:
268
- latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
269
-
270
- if not output_type == "latent":
271
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
272
- image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
273
- else:
274
- image = latents
275
- has_nsfw_concept = None
276
-
277
- if has_nsfw_concept is None:
278
- do_denormalize = [True] * image.shape[0]
279
- else:
280
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
281
-
282
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
283
-
284
- self.maybe_free_model_hooks()
285
-
286
- if not return_dict:
287
- return (image, has_nsfw_concept)
288
-
289
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
290
-