| | import inspect |
| | from typing import List, Optional, Union |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | import PIL |
| | from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel |
| | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
| | from tqdm.auto import tqdm |
| | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
| |
|
| |
|
def preprocess_image(image):
    """Convert a PIL image into a normalized, batched, channels-first tensor.

    The image is converted to RGB, resized so that both sides are multiples
    of 32 (rounding down), rescaled from ``[0, 255]`` to ``[-1, 1]`` and
    rearranged from ``(H, W, C)`` to ``(1, C, H, W)``.

    Args:
        image: A ``PIL.Image.Image``. Any mode is accepted; it is converted
            to RGB before further processing.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, H', W')`` with values in
        ``[-1, 1]``, where ``H'`` and ``W'`` are the input height and width
        rounded down to multiples of 32.
    """
    # Normalize the mode first: a grayscale or palette image would yield a
    # 2-D array (breaking the 4-axis transpose below) and an RGBA image would
    # yield a fourth channel. This mirrors the ``convert("L")`` done for masks.
    image = image.convert("RGB")

    width, height = image.size
    # Round each side down to the nearest multiple of 32.
    width -= width % 32
    height -= height % 32
    image = image.resize((width, height), resample=PIL.Image.LANCZOS)

    pixels = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    batched = pixels[None].transpose(0, 3, 1, 2)
    # Map [0, 1] to [-1, 1].
    return 2.0 * torch.from_numpy(batched) - 1.0
| |
|
| |
|
def preprocess_mask(mask):
    """Turn a PIL mask into an inverted, 4-channel, batched float tensor.

    The mask is converted to single-channel grayscale, its size is rounded
    down to multiples of 32 and then divided by 8, and the pixel values are
    scaled to ``[0, 1]``. The result is replicated over 4 channels, given a
    leading batch axis and inverted (``1 - value``), so white input pixels
    become 0 and black input pixels become 1.

    Args:
        mask: A ``PIL.Image.Image`` (any mode; converted to ``"L"``).

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 4, H // 8, W // 8)`` where
        ``H`` and ``W`` are the mask height and width rounded down to
        multiples of 32.
    """
    grayscale = mask.convert("L")

    width, height = grayscale.size
    # Round each side down to the nearest multiple of 32.
    width -= width % 32
    height -= height % 32

    # Nearest-neighbour resampling keeps the mask values un-interpolated.
    downsized = grayscale.resize((width // 8, height // 8), resample=PIL.Image.NEAREST)
    values = np.array(downsized).astype(np.float32) / 255.0

    # (h, w) -> (4, h, w) -> (1, 4, h, w)
    stacked = np.tile(values, (4, 1, 1))[None]

    # Invert so that 1 marks pixels that were black in the input mask.
    return torch.from_numpy(1 - stacked)
| |
|
class StableDiffusionInpaintingPipeline(DiffusionPipeline):
    """Text-guided inpainting built from Stable Diffusion components.

    The initial image is encoded into VAE latents, noised according to
    ``strength`` and then denoised under text guidance. After every scheduler
    step the latents are blended with a re-noised copy of the original
    latents through the mask, so that only the region selected by the mask is
    regenerated.

    Because the mask is inverted during preprocessing, white areas of
    ``mask_image`` are repainted and black areas keep the original content.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register the sub-models and switch the scheduler to the ``"pt"`` format."""
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: "PIL.Image.Image",
        mask_image: "PIL.Image.Image",
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
    ):
        """Inpaint ``init_image`` inside the masked region, guided by ``prompt``.

        Args:
            prompt: A single prompt or a list of prompts. The number of
                prompts determines the batch size; the same image and mask
                are repeated for every prompt.
            init_image: PIL image to inpaint. It is passed to
                ``preprocess_image``.
            mask_image: PIL mask passed to ``preprocess_mask``. White pixels
                are repainted, black pixels are preserved.
            strength: Value in ``[0, 1]`` controlling how many of the
                ``num_inference_steps`` are actually run and how much noise
                is added to the initial latents.
            num_inference_steps: Number of scheduler timesteps to configure.
            guidance_scale: Classifier-free guidance weight. Guidance is only
                applied when this is greater than ``1.0``.
            eta: Forwarded to ``scheduler.step`` only if that method accepts
                an ``eta`` parameter.
            generator: Optional ``torch.Generator`` used for sampling the VAE
                latents and the noise.
            output_type: When ``"pil"``, images are converted to PIL images;
                otherwise the array returned by the safety checker is
                returned as is.

        Returns:
            A dict with keys ``"sample"`` (the generated images) and
            ``"nsfw_content_detected"`` (the safety checker's second output).

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, if
                ``strength`` is outside ``[0, 1]``, or if the preprocessed
                mask and the image latents differ in shape.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Configure the timesteps; pass ``offset=1`` only to schedulers whose
        # ``set_timesteps`` accepts it.
        accepts_offset = "offset" in inspect.signature(self.scheduler.set_timesteps).parameters
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Scale factor applied to the latents after encoding and undone
        # before decoding.
        latent_scale = 0.18215

        # Encode the initial image into latents.
        init_image = preprocess_image(init_image).to(self.device)
        init_latent_dist = self.vae.encode(init_image).latent_dist
        init_latents = init_latent_dist.sample(generator=generator)
        init_latents = latent_scale * init_latents

        # Repeat for every prompt and keep a clean copy for the masking step.
        init_latents = torch.cat([init_latents] * batch_size)
        init_latents_orig = init_latents

        # Prepare the mask with the same batch size.
        mask = preprocess_mask(mask_image).to(self.device)
        mask = torch.cat([mask] * batch_size)

        if mask.shape != init_latents.shape:
            raise ValueError("The mask and init_image should be the same size!")

        # Pick the timestep at which the initial latents are noised.
        # NOTE(review): when the scheduler takes no ``offset`` and ``strength``
        # is 0, ``init_timestep`` is 0 and ``timesteps[-0]`` selects the first
        # entry of the schedule rather than the last - confirm whether that
        # case needs handling.
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # Noise the latents to the chosen timestep. The same ``noise`` tensor
        # is reused inside the loop to re-noise the original latents.
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # Encode the prompt.
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            # Encode empty prompts for the unconditional branch.
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # Concatenate so both branches run in a single forward pass;
            # the unconditional half comes first (see ``chunk(2)`` below).
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Forward ``eta`` only to schedulers whose ``step`` accepts it.
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        # Hand tqdm the sequence itself (not an ``enumerate`` iterator) so it
        # can read the length and display a total.
        for t in tqdm(self.scheduler.timesteps[t_start:]):
            # Duplicate the latents when running both guidance branches.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # Predict the noise residual.
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Advance the latents by one scheduler step.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

            # Masking: where ``mask`` is 1, overwrite with the original
            # latents noised for timestep ``t``; elsewhere keep the result.
            # NOTE(review): the original latents are noised for ``t`` while
            # ``latents`` was just stepped past ``t`` - confirm whether the
            # next timestep should be used here.
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
            latents = (init_latents_proper * mask) + (latents * (1 - mask))

        # Restore the clean original latents in the preserved region; without
        # this the kept area would retain the noise added for the last
        # timestep of the loop.
        latents = (init_latents_orig * mask) + (latents * (1 - mask))

        # Undo the latent scaling and decode.
        latents = 1 / latent_scale * latents
        image = self.vae.decode(latents).sample

        # Map [-1, 1] to [0, 1] and move to a channels-last numpy array.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # Run the safety checker on the decoded images.
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}