rm old pipline
Browse files- pipeline_img2mvimg.py +0 -290
pipeline_img2mvimg.py
DELETED
|
@@ -1,290 +0,0 @@
|
|
| 1 |
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
from diffusers import AutoencoderKL, StableDiffusionImageVariationPipeline, UNet2DConditionModel
|
| 7 |
-
from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
|
| 8 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
|
| 9 |
-
from PIL import Image
|
| 10 |
-
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 11 |
-
from omegaconf import DictConfig, OmegaConf
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class StableDiffusionImage2MVCustomPipeline(
|
| 16 |
-
StableDiffusionImageVariationPipeline
|
| 17 |
-
):
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
vae: AutoencoderKL,
|
| 21 |
-
image_encoder: CLIPVisionModelWithProjection,
|
| 22 |
-
unet: UNet2DConditionModel,
|
| 23 |
-
scheduler: KarrasDiffusionSchedulers,
|
| 24 |
-
safety_checker: StableDiffusionSafetyChecker,
|
| 25 |
-
feature_extractor: CLIPImageProcessor,
|
| 26 |
-
requires_safety_checker: bool = True,
|
| 27 |
-
latents_offset=None,
|
| 28 |
-
noisy_cond_latents=False,
|
| 29 |
-
condition_offset=True,
|
| 30 |
-
):
|
| 31 |
-
super().__init__(
|
| 32 |
-
vae=vae,
|
| 33 |
-
image_encoder=image_encoder,
|
| 34 |
-
unet=unet,
|
| 35 |
-
scheduler=scheduler,
|
| 36 |
-
safety_checker=safety_checker,
|
| 37 |
-
feature_extractor=feature_extractor,
|
| 38 |
-
requires_safety_checker=requires_safety_checker
|
| 39 |
-
)
|
| 40 |
-
latents_offset = tuple(latents_offset) if latents_offset is not None else None
|
| 41 |
-
self.latents_offset = latents_offset
|
| 42 |
-
if latents_offset is not None:
|
| 43 |
-
self.register_to_config(latents_offset=latents_offset)
|
| 44 |
-
if noisy_cond_latents:
|
| 45 |
-
raise NotImplementedError("Noisy condition latents not supported Now.")
|
| 46 |
-
self.condition_offset = condition_offset
|
| 47 |
-
self.register_to_config(condition_offset=condition_offset)
|
| 48 |
-
|
| 49 |
-
def encode_latents(self, image: Image.Image, device, dtype, height, width):
|
| 50 |
-
images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
|
| 51 |
-
# NOTE: .mode() for condition
|
| 52 |
-
latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
|
| 53 |
-
if self.latents_offset is not None and self.condition_offset:
|
| 54 |
-
return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
| 55 |
-
else:
|
| 56 |
-
return latents
|
| 57 |
-
|
| 58 |
-
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
|
| 59 |
-
dtype = next(self.image_encoder.parameters()).dtype
|
| 60 |
-
|
| 61 |
-
if not isinstance(image, torch.Tensor):
|
| 62 |
-
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
| 63 |
-
|
| 64 |
-
image = image.to(device=device, dtype=dtype)
|
| 65 |
-
image_embeddings = self.image_encoder(image).image_embeds
|
| 66 |
-
image_embeddings = image_embeddings.unsqueeze(1)
|
| 67 |
-
|
| 68 |
-
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
| 69 |
-
bs_embed, seq_len, _ = image_embeddings.shape
|
| 70 |
-
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
| 71 |
-
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 72 |
-
|
| 73 |
-
if do_classifier_free_guidance:
|
| 74 |
-
# NOTE: the same as original code
|
| 75 |
-
negative_prompt_embeds = torch.zeros_like(image_embeddings)
|
| 76 |
-
# For classifier free guidance, we need to do two forward passes.
|
| 77 |
-
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 78 |
-
# to avoid doing two forward passes
|
| 79 |
-
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
|
| 80 |
-
|
| 81 |
-
return image_embeddings
|
| 82 |
-
|
| 83 |
-
@torch.no_grad()
|
| 84 |
-
def __call__(
|
| 85 |
-
self,
|
| 86 |
-
image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
|
| 87 |
-
height: Optional[int] = 1024,
|
| 88 |
-
width: Optional[int] = 1024,
|
| 89 |
-
height_cond: Optional[int] = 512,
|
| 90 |
-
width_cond: Optional[int] = 512,
|
| 91 |
-
num_inference_steps: int = 50,
|
| 92 |
-
guidance_scale: float = 7.5,
|
| 93 |
-
num_images_per_prompt: Optional[int] = 1,
|
| 94 |
-
eta: float = 0.0,
|
| 95 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 96 |
-
latents: Optional[torch.FloatTensor] = None,
|
| 97 |
-
output_type: Optional[str] = "pil",
|
| 98 |
-
return_dict: bool = True,
|
| 99 |
-
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 100 |
-
callback_steps: int = 1,
|
| 101 |
-
strength: float = 0.0,
|
| 102 |
-
):
|
| 103 |
-
r"""
|
| 104 |
-
The call function to the pipeline for generation.
|
| 105 |
-
|
| 106 |
-
Args:
|
| 107 |
-
image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
|
| 108 |
-
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
| 109 |
-
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
| 110 |
-
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 111 |
-
The height in pixels of the generated image.
|
| 112 |
-
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 113 |
-
The width in pixels of the generated image.
|
| 114 |
-
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 115 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 116 |
-
expense of slower inference. This parameter is modulated by `strength`.
|
| 117 |
-
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 118 |
-
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 119 |
-
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 120 |
-
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 121 |
-
The number of images to generate per prompt.
|
| 122 |
-
eta (`float`, *optional*, defaults to 0.0):
|
| 123 |
-
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 124 |
-
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 125 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 126 |
-
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 127 |
-
generation deterministic.
|
| 128 |
-
latents (`torch.FloatTensor`, *optional*):
|
| 129 |
-
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 130 |
-
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 131 |
-
tensor is generated by sampling using the supplied random `generator`.config.j
|
| 132 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 133 |
-
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 134 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 135 |
-
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 136 |
-
plain tuple.
|
| 137 |
-
callback (`Callable`, *optional*):
|
| 138 |
-
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 139 |
-
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 140 |
-
callback_steps (`int`, *optional*, defaults to 1):
|
| 141 |
-
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 142 |
-
every step.
|
| 143 |
-
|
| 144 |
-
Returns:
|
| 145 |
-
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 146 |
-
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 147 |
-
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 148 |
-
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 149 |
-
"not-safe-for-work" (nsfw) content.
|
| 150 |
-
|
| 151 |
-
Examples:
|
| 152 |
-
|
| 153 |
-
```py
|
| 154 |
-
from diffusers import StableDiffusionImageVariationPipeline
|
| 155 |
-
from PIL import Image
|
| 156 |
-
from io import BytesIO
|
| 157 |
-
import requests
|
| 158 |
-
|
| 159 |
-
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
| 160 |
-
"lambdalabs/sd-image-variations-diffusers", revision="v2.0"
|
| 161 |
-
)
|
| 162 |
-
pipe = pipe.to("cuda")
|
| 163 |
-
|
| 164 |
-
url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
|
| 165 |
-
|
| 166 |
-
response = requests.get(url)
|
| 167 |
-
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 168 |
-
|
| 169 |
-
out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
|
| 170 |
-
out["images"][0].save("result.jpg")
|
| 171 |
-
```
|
| 172 |
-
"""
|
| 173 |
-
# 0. Default height and width to unet
|
| 174 |
-
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 175 |
-
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 176 |
-
|
| 177 |
-
# 1. Check inputs. Raise error if not correct
|
| 178 |
-
self.check_inputs(image, height, width, callback_steps)
|
| 179 |
-
|
| 180 |
-
# 2. Define call parameters
|
| 181 |
-
if isinstance(image, Image.Image):
|
| 182 |
-
batch_size = 1
|
| 183 |
-
elif len(image) == 1:
|
| 184 |
-
image = image[0]
|
| 185 |
-
batch_size = 1
|
| 186 |
-
else:
|
| 187 |
-
raise NotImplementedError()
|
| 188 |
-
# elif isinstance(image, list):
|
| 189 |
-
# batch_size = len(image)
|
| 190 |
-
# else:
|
| 191 |
-
# batch_size = image.shape[0]
|
| 192 |
-
device = self._execution_device
|
| 193 |
-
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 194 |
-
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 195 |
-
# corresponds to doing no classifier free guidance.
|
| 196 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 197 |
-
|
| 198 |
-
# 3. Encode input image
|
| 199 |
-
emb_image = image
|
| 200 |
-
|
| 201 |
-
image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
|
| 202 |
-
cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
|
| 203 |
-
cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
|
| 204 |
-
image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
|
| 205 |
-
if do_classifier_free_guidance:
|
| 206 |
-
image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
|
| 207 |
-
|
| 208 |
-
# 4. Prepare timesteps
|
| 209 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 210 |
-
timesteps = self.scheduler.timesteps
|
| 211 |
-
|
| 212 |
-
# 5. Prepare latent variables
|
| 213 |
-
num_channels_latents = self.unet.config.out_channels
|
| 214 |
-
latents = self.prepare_latents(
|
| 215 |
-
batch_size * num_images_per_prompt,
|
| 216 |
-
num_channels_latents,
|
| 217 |
-
height,
|
| 218 |
-
width,
|
| 219 |
-
image_embeddings.dtype,
|
| 220 |
-
device,
|
| 221 |
-
generator,
|
| 222 |
-
latents,
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
# 6. Prepare extra step kwargs.
|
| 227 |
-
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 228 |
-
# 7. Denoising loop
|
| 229 |
-
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 230 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 231 |
-
|
| 232 |
-
class_labels = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3]).to(latents.device)
|
| 233 |
-
|
| 234 |
-
for i, t in enumerate(timesteps[int(len(timesteps)*strength):]):
|
| 235 |
-
# expand the latents if we are doing classifier free guidance
|
| 236 |
-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 237 |
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 238 |
-
|
| 239 |
-
num_repeat = len(latent_model_input) // len(cond_latents) # [2, 4, 32, 32]
|
| 240 |
-
cat_latents = torch.stack([cond_latents] * num_repeat, 1).reshape(*latent_model_input.shape)
|
| 241 |
-
|
| 242 |
-
sample = torch.cat([latent_model_input, cat_latents], dim=1)
|
| 243 |
-
noise_pred = self.unet(
|
| 244 |
-
sample,
|
| 245 |
-
t,
|
| 246 |
-
image_embeddings,
|
| 247 |
-
cross_attention_kwargs=None,
|
| 248 |
-
class_labels=class_labels,
|
| 249 |
-
).sample
|
| 250 |
-
# perform guidance
|
| 251 |
-
if do_classifier_free_guidance:
|
| 252 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 253 |
-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 254 |
-
|
| 255 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 256 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 257 |
-
|
| 258 |
-
# call the callback, if provided
|
| 259 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 260 |
-
progress_bar.update()
|
| 261 |
-
if callback is not None and i % callback_steps == 0:
|
| 262 |
-
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 263 |
-
callback(step_idx, t, latents)
|
| 264 |
-
|
| 265 |
-
self.maybe_free_model_hooks()
|
| 266 |
-
|
| 267 |
-
if self.latents_offset is not None:
|
| 268 |
-
latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
| 269 |
-
|
| 270 |
-
if not output_type == "latent":
|
| 271 |
-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 272 |
-
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
|
| 273 |
-
else:
|
| 274 |
-
image = latents
|
| 275 |
-
has_nsfw_concept = None
|
| 276 |
-
|
| 277 |
-
if has_nsfw_concept is None:
|
| 278 |
-
do_denormalize = [True] * image.shape[0]
|
| 279 |
-
else:
|
| 280 |
-
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 281 |
-
|
| 282 |
-
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 283 |
-
|
| 284 |
-
self.maybe_free_model_hooks()
|
| 285 |
-
|
| 286 |
-
if not return_dict:
|
| 287 |
-
return (image, has_nsfw_concept)
|
| 288 |
-
|
| 289 |
-
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|