| | |
| |
|
| | import imp |
| | import numpy as np |
| | import cv2 |
| | import torch |
| | import random |
| | from PIL import Image, ImageDraw, ImageFont |
| | import copy |
| | from typing import Optional, Union, Tuple, List, Callable, Dict, Any |
| | from tqdm.notebook import tqdm |
| | from diffusers.utils import BaseOutput, logging |
| | from diffusers.models.embeddings import TimestepEmbedding, Timesteps |
| | from diffusers.models.unet_2d_blocks import ( |
| | CrossAttnDownBlock2D, |
| | CrossAttnUpBlock2D, |
| | DownBlock2D, |
| | UNetMidBlock2DCrossAttn, |
| | UpBlock2D, |
| | get_down_block, |
| | get_up_block, |
| | ) |
| | from diffusers.models.unet_2d_condition import UNet2DConditionOutput |
| | from copy import deepcopy |
| | import json |
| |
|
| | import inspect |
| | import os |
| | import warnings |
| | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| | import torch.nn.functional as F |
| | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
| |
|
| | from diffusers.image_processor import VaeImageProcessor |
| | from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin |
| | from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel |
| | from diffusers.schedulers import KarrasDiffusionSchedulers |
| | from diffusers.utils.torch_utils import is_compiled_module |
| |
|
| | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
| | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
| | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel |
| | from tqdm import tqdm |
| | from controlnet_aux import HEDdetector, OpenposeDetector |
| | import time |
| |
|
def seed_everything(seed):
    """Seed every RNG used by this module for reproducible sampling.

    Covers torch CPU, torch CUDA (current device and, as a fix over the
    original, all devices), Python's `random`, and NumPy.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
| |
|
def get_promptls(prompt_path):
    """Read a COCO-style caption JSON and return the captions as a list.

    Each record's 'caption' field is returned with '/' replaced by '_' so the
    captions are safe to use as file names.
    """
    with open(prompt_path) as fp:
        records = json.load(fp)
    return [record['caption'].replace('/', '_') for record in records]
| |
|
def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, crop it, center-square it, and resize to 512x512.

    Args:
        image_path: Either a file path (str) or an already-loaded numpy array
            of shape (H, W, C).
        left, right, top, bottom: Optional crop margins in pixels; each is
            clamped so the crop can never produce an empty image.

    Returns:
        A (512, 512, C) uint8 numpy array.
    """
    if type(image_path) is str:
        image = np.array(Image.open(image_path))
        if image.ndim == 2:
            # Grayscale: add a trailing channel axis.
            image = image.reshape(image.shape[0], image.shape[1], 1).astype('uint8')
        elif image.shape[2] > 3:
            # RGBA (or more channels): keep only RGB.  The original tested
            # `image.ndim > 3`, which never fires for RGBA arrays (ndim == 3).
            image = image[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    # Clamp the margins against the relevant dimension.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Bug fix: the original clamped `top` with `h - left - 1`, mixing the
    # horizontal margin into the vertical clamp.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    # Center-crop the longer side to obtain a square before resizing.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image
| |
|
def get_canny(image_path):
    """Return a Canny-edge ControlNet condition for the given image.

    The image is loaded at 512x512, edge-detected with thresholds (100, 200),
    and the single-channel edge map is replicated to three channels.
    """
    loaded = np.array(load_512(image_path))
    edges = cv2.Canny(loaded, 100, 200)
    # Replicate the single edge channel into an RGB-shaped array.
    rgb_edges = np.concatenate([edges[:, :, None]] * 3, axis=2)
    return Image.fromarray(rgb_edges)
| |
|
| |
|
def get_scribble(image_path, hed):
    """Return a scribble ControlNet condition for the given image.

    Args:
        image_path: Path (or array) accepted by `load_512`.
        hed: A HED edge detector instance supporting `scribble=True`.
    """
    resized = load_512(image_path)
    return hed(resized, scribble=True)
| |
|
def get_cocoimages(prompt_path, img_path='COCO2017-val/val2017'):
    """Build {image, prompt} pairs with Canny-edge conditions from COCO captions.

    Args:
        prompt_path: JSON file of records with 'caption' and 'image_id' keys.
        img_path: Directory holding the COCO images (new parameter; defaults
            to the previously hard-coded location, so existing callers are
            unaffected).

    Returns:
        List of dicts: {'image': PIL canny condition, 'prompt': caption}.
        Images that are missing or unreadable are skipped silently.
    """
    data_ls = []
    with open(prompt_path) as f:
        prompt_ls = json.load(f)
    for prompt in tqdm(prompt_ls):
        caption = prompt['caption'].replace('/', '_')
        # COCO file names are the image id zero-padded to 12 digits.
        image_id = str(prompt['image_id']).zfill(12) + '.jpg'
        image_path = os.path.join(img_path, image_id)
        try:
            image = get_canny(image_path)
        except Exception:
            # Best effort: skip unreadable/missing images (was a bare except).
            continue
        data_ls.append({'image': image, 'prompt': caption})
    return data_ls
| |
|
def get_cocoimages2(prompt_path, img_path='COCO2017-val/val2017'):
    """Scribble-condition variant of `get_cocoimages`.

    Args:
        prompt_path: JSON file of records with 'caption' and 'image_id' keys.
        img_path: Directory holding the COCO images (new parameter; defaults
            to the previously hard-coded location, so existing callers are
            unaffected).

    Returns:
        List of dicts: {'image': scribble condition, 'prompt': caption}.
        Images that are missing or unreadable are skipped silently.
    """
    data_ls = []
    with open(prompt_path) as f:
        prompt_ls = json.load(f)
    # Load the HED detector once and reuse it for every image.
    hed = HEDdetector.from_pretrained('ControlNet/detector_weights/annotator', filename='network-bsds500.pth')
    for prompt in tqdm(prompt_ls):
        caption = prompt['caption'].replace('/', '_')
        # COCO file names are the image id zero-padded to 12 digits.
        image_id = str(prompt['image_id']).zfill(12) + '.jpg'
        image_path = os.path.join(img_path, image_id)
        try:
            image = get_scribble(image_path, hed)
        except Exception:
            # Best effort: skip unreadable/missing images (was a bare except).
            continue
        data_ls.append({'image': image, 'prompt': caption})
    return data_ls
| |
|
def warpped_feature(sample, step):
    """Tile a CFG feature batch so one forward pass serves `step` timesteps.

    `sample` is (batch, dim, h, w): the first half is the unconditional
    branch, the second half the conditional branch.  Each half is repeated
    `step` times along the batch dimension and the halves re-concatenated,
    preserving the uncond-first / cond-second layout.
    """
    bs, dim, h, w = sample.shape
    uncond_half, cond_half = sample.chunk(2)
    tiled_uncond = uncond_half.repeat(step, 1, 1, 1)
    tiled_cond = cond_half.repeat(step, 1, 1, 1)
    return torch.cat([tiled_uncond, tiled_cond])
| |
|
def warpped_skip_feature(block_samples, step):
    """Expand every skip-connection feature map for `step` parallel timesteps.

    Applies `warpped_feature` to each tensor in `block_samples` and returns
    the results as a tuple in the same order.
    """
    return tuple(warpped_feature(block_sample, step) for block_sample in block_samples)
| |
|
def warpped_text_emb(text_emb, step):
    """Tile CFG text embeddings so one forward pass serves `step` timesteps.

    `text_emb` is (batch, tokens, dim): the first half is the unconditional
    embedding, the second half the conditional one.  Each half is repeated
    `step` times and the halves re-concatenated in uncond-first order.
    """
    bs, token_len, dim = text_emb.shape
    uncond_half, cond_half = text_emb.chunk(2)
    return torch.cat([uncond_half.repeat(step, 1, 1), cond_half.repeat(step, 1, 1)])
| |
|
def warpped_timestep(timesteps, bs):
    """Expand a list of scalar timesteps into a per-sample timestep vector.

    Args:
        timesteps: List of 0-d tensors, e.g. [981, 961, 941].
        bs: Full CFG batch size; each timestep is expanded to `bs // 2`
            entries per guidance branch.

    Returns:
        1-d tensor laying out the uncond half followed by the cond half, each
        containing every timestep repeated `bs // 2` times.
    """
    half_batch = bs // 2
    per_step = [t[None].expand(half_batch) for t in timesteps]
    flat = torch.cat(per_step)
    # Duplicate for the two guidance branches, then flatten.
    return flat.repeat(2, 1).reshape(-1)
| |
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Rescale `noise_cfg` toward the text prediction's standard deviation.

    Implements the fix from "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4: the CFG
    output is rescaled so its per-sample std matches the text branch, then
    blended with the original by `guidance_rescale` (0 = no change).
    """
    reduce_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=reduce_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Match the text branch's magnitude, undoing CFG's variance inflation.
    rescaled = noise_cfg * (std_text / std_cfg)
    # Linear blend between the rescaled and the original prediction.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
| |
|
def register_normal_pipeline(pipe):
    # Attach a standard (one-UNet-call-per-timestep) Stable Diffusion sampling
    # loop to `pipe` as `pipe.call`.  `self` is bound to `pipe` through the
    # closure below, so the result is used as `pipe.call(prompt=..., ...)`.
    def new_call(self):
        @torch.no_grad()
        def call(
            prompt: Union[str, List[str]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            clip_skip: Optional[int] = None,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            **kwargs,
        ):
            """Run the standard denoising loop and return the generated images.

            Mirrors diffusers' StableDiffusionPipeline.__call__, with one
            addition: at timesteps below 500/1000 a small fraction of the
            initial noise is re-injected into the latents (see the loop).
            """
            # Deprecated callback API still accepted through **kwargs.
            callback = kwargs.pop("callback", None)
            callback_steps = kwargs.pop("callback_steps", None)

            # 0. Default height/width to the UNet's native resolution.
            height = height or self.unet.config.sample_size * self.vae_scale_factor
            width = width or self.unet.config.sample_size * self.vae_scale_factor

            # 1. Validate inputs (raises on inconsistent arguments).
            self.check_inputs(
                prompt,
                height,
                width,
                callback_steps,
                negative_prompt,
                prompt_embeds,
                negative_prompt_embeds,
                callback_on_step_end_tensor_inputs,
            )

            self._guidance_scale = guidance_scale
            self._guidance_rescale = guidance_rescale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs

            # 2. Determine batch size from prompt(s) or precomputed embeddings.
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device

            # 3. Encode the prompt (and negative prompt) into text embeddings.
            lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )

            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
            )
            # For classifier-free guidance the uncond/cond embeddings are
            # concatenated so both branches run in a single UNet forward pass.
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

            # 4. Prepare the scheduler timesteps.
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

            # 5. Prepare the initial latent noise.
            num_channels_latents = self.unet.config.in_channels
            latents = self.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 6. Extra kwargs for schedulers that accept eta / generator.
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 6.5 Optional guidance-scale conditioning (time_cond_proj UNets).
            timestep_cond = None
            if self.unet.config.time_cond_proj_dim is not None:
                guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
                timestep_cond = self.get_guidance_scale_embedding(
                    guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
                ).to(device=device, dtype=latents.dtype)

            # 7. Denoising loop.
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            self._num_timesteps = len(timesteps)
            init_latents = latents.detach().clone()
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # NOTE(review): in the second half of sampling a small
                    # fraction of the initial noise is added back each step;
                    # the same 0.003 factor appears in the parallel pipeline
                    # below — confirm against the originating paper/repo.
                    if t/1000 < 0.5:
                        latents = latents + 0.003*init_latents
                    # Expose the current step index to the patched UNet
                    # forward (read as `self.order` in faster_forward).
                    setattr(self.unet, 'order', i)

                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # Predict the noise residual.
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        return_dict=False,
                    )[0]

                    # Classifier-free guidance.
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                        # Rescale per https://arxiv.org/pdf/2305.08891.pdf §3.4.
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                    # Scheduler step: x_t -> x_{t-1}.
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # Progress bar update and legacy callback invocation.
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents)

            # 8. Decode latents unless raw latents were requested.
            if not output_type == "latent":
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                    0
                ]
                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            else:
                image = latents
                has_nsfw_concept = None

            if has_nsfw_concept is None:
                do_denormalize = [True] * image.shape[0]
            else:
                do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

            # Offload models again if CPU-offload hooks are active.
            self.maybe_free_model_hooks()

            if not return_dict:
                return (image, has_nsfw_concept)

            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
        return call
    pipe.call = new_call(pipe)
| |
|
def register_parallel_pipeline(pipe, mod = '50ls'):
    # Attach a "parallel" sampling loop to `pipe` as `pipe.call`: timesteps
    # between key steps are grouped and denoised by a single batched UNet pass
    # that reuses encoder features cached by the patched UNet forward
    # (see `register_faster_forward`).  `mod` selects the key-step schedule:
    # '50ls' uses a fixed index list; an int means "every `mod`-th step".
    def new_call(self):
        @torch.no_grad()
        def call(
            prompt: Union[str, List[str]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            clip_skip: Optional[int] = None,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            **kwargs,
        ):
            """Generate images using grouped (batched-over-time) denoising."""
            # Deprecated callback API still accepted through **kwargs
            # (note: unlike the normal pipeline, `callback` is never invoked
            # in this loop).
            callback = kwargs.pop("callback", None)
            callback_steps = kwargs.pop("callback_steps", None)

            # 0. Default height/width to the UNet's native resolution.
            height = height or self.unet.config.sample_size * self.vae_scale_factor
            width = width or self.unet.config.sample_size * self.vae_scale_factor

            # 1. Validate inputs.
            self.check_inputs(
                prompt,
                height,
                width,
                callback_steps,
                negative_prompt,
                prompt_embeds,
                negative_prompt_embeds,
                callback_on_step_end_tensor_inputs,
            )

            self._guidance_scale = guidance_scale
            self._guidance_rescale = guidance_rescale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs

            # 2. Determine batch size from prompt(s) or precomputed embeddings.
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device

            # 3. Encode the prompt (and negative prompt).
            lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )

            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
            )
            # Concatenate uncond/cond embeddings for single-pass CFG.
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

            # 4. Prepare the scheduler timesteps.
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

            # 5. Prepare the initial latent noise.
            num_channels_latents = self.unet.config.in_channels
            latents = self.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 6. Extra kwargs for schedulers that accept eta / generator.
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 6.5 Optional guidance-scale conditioning (time_cond_proj UNets).
            timestep_cond = None
            if self.unet.config.time_cond_proj_dim is not None:
                guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
                timestep_cond = self.get_guidance_scale_embedding(
                    guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
                ).to(device=device, dtype=latents.dtype)

            # 7. Grouped denoising loop.
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            self._num_timesteps = len(timesteps)
            init_latents = latents.detach().clone()

            all_steps = len(self.scheduler.timesteps)
            curr_step = 0
            # `cond(i)` is True at "key" step indices where the UNet encoder
            # features are recomputed; intermediate steps reuse cached ones.
            # NOTE(review): the lambda parameter is a step *index*, not a
            # scheduler timestep value, despite its name.
            if mod == '50ls':
                cond = lambda timestep: timestep in [0,1,2,3,5,10,15,25,35]
            elif isinstance(mod, int):
                cond = lambda timestep: timestep % mod ==0
            else:
                raise Exception("Currently not supported, But you can modify the code to customize the keytime")
            while curr_step<all_steps:
                # Tell the patched UNet which step index this group starts at.
                # NOTE(review): `register_time` is not defined in this file —
                # it must come from elsewhere in the project.
                register_time(self.unet, curr_step)
                # Gather the key step plus all following non-key steps.
                time_ls = [self.scheduler.timesteps[curr_step]]
                curr_step += 1
                while not cond(curr_step):
                    if curr_step<all_steps:
                        time_ls.append(self.scheduler.timesteps[curr_step])
                        curr_step += 1
                    else:
                        break

                # NOTE(review): no `scheduler.scale_model_input` call here,
                # unlike the normal pipeline — confirm the scheduler in use
                # does not require input scaling.
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # One UNet call predicts noise for every timestep in time_ls
                # (the patched forward expands the batch accordingly).
                noise_pred = self.unet(
                    latent_model_input,
                    time_ls,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # Classifier-free guidance.
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Rescale per https://arxiv.org/pdf/2305.08891.pdf §3.4.
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # Split the batched prediction back into per-timestep chunks
                # and apply the scheduler sequentially.
                bs = noise_pred.shape[0]
                bs_perstep = bs//len(time_ls)

                denoised_latent = latents
                for i, timestep in enumerate(time_ls):
                    # Re-inject a small amount of the initial noise in the
                    # second half of sampling (t < 500 of 1000) — mirrors the
                    # normal pipeline above.
                    if timestep/1000 < 0.5:
                        denoised_latent = denoised_latent + 0.003*init_latents
                    curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep]
                    denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0]

                latents = denoised_latent

            # 8. Decode latents unless raw latents were requested.
            if not output_type == "latent":
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                    0
                ]
                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            else:
                image = latents
                has_nsfw_concept = None

            if has_nsfw_concept is None:
                do_denormalize = [True] * image.shape[0]
            else:
                do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

            # Offload models again if CPU-offload hooks are active.
            self.maybe_free_model_hooks()

            if not return_dict:
                return (image, has_nsfw_concept)

            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
        return call
    pipe.call = new_call(pipe)
| |
|
def register_faster_forward(model, mod = '50ls'):
    """Monkey-patch a UNet2DConditionModel with a feature-caching forward.

    On "key" steps (selected by `mod` against `self.order`, which the
    pipeline sets before each call) the full encoder (down blocks + mid
    block) runs and its outputs are cached on the model; on non-key steps the
    cached features are reused and only the decoder (up blocks) runs.  When
    `timestep` is a list, features are tiled so one pass serves all listed
    timesteps (see `warpped_*` helpers).

    The patch is only applied when `model` is a `UNet2DConditionModel`
    (checked by class name); other models are left untouched.
    """
    def faster_forward(self):
        def forward(
            sample: torch.FloatTensor,
            timestep: Union[torch.Tensor, float, int],
            encoder_hidden_states: torch.Tensor,
            class_labels: Optional[torch.Tensor] = None,
            timestep_cond: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            mid_block_additional_residual: Optional[torch.Tensor] = None,
            down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            return_dict: bool = True,
        ) -> Union[UNet2DConditionOutput, Tuple]:
            r"""
            Args:
                sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
                timestep (`torch.FloatTensor` or `float` or `int` or list): (batch) timesteps;
                    a list triggers the multi-timestep (parallel) path.
                encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
                return_dict (`bool`, *optional*, defaults to `True`):
                    Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
                cross_attention_kwargs (`dict`, *optional*):
                    A kwargs dictionary passed along to the `AttentionProcessor`.

            Returns:
                [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a
                `tuple` whose first element is the sample tensor.
            """
            # The sample must be divisible by 2**num_upsamplers; otherwise we
            # must pass explicit sizes to the upsamplers.
            default_overall_up_factor = 2**self.num_upsamplers

            forward_upsample_size = False
            upsample_size = None

            if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
                print("Forward upsample size to force interpolation output size.")
                forward_upsample_size = True

            # Convert a 0/1 mask to an additive attention bias.
            if attention_mask is not None:
                attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # Optionally map [0, 1] inputs to [-1, 1].
            if self.config.center_input_sample:
                sample = 2 * sample - 1.0

            # 1. Time embedding.  A list timestep means this pass covers
            # `step` consecutive timesteps; only the first is embedded here
            # (the full list is handled below when features are tiled).
            if isinstance(timestep, list):
                timesteps = timestep[0]
                step = len(timestep)
            else:
                timesteps = timestep
                step = 1
            if not torch.is_tensor(timesteps) and (not isinstance(timesteps, list)):
                # mps does not support float64 / int64.
                is_mps = sample.device.type == "mps"
                if isinstance(timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
            elif (not isinstance(timesteps, list)) and len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(sample.device)

            if (not isinstance(timesteps, list)) and len(timesteps.shape) == 1:
                # Broadcast to the batch dimension (ONNX/Core ML friendly).
                timesteps = timesteps.expand(sample.shape[0])
            elif isinstance(timesteps, list):
                # Warp a timestep list into a per-sample batch vector.
                timesteps = warpped_timestep(timesteps, sample.shape[0]).to(sample.device)
            t_emb = self.time_proj(timesteps)

            # Cast to the model dtype (time_proj may run in fp32).
            t_emb = t_emb.to(dtype=self.dtype)

            emb = self.time_embedding(t_emb, timestep_cond)

            if self.class_embedding is not None:
                if class_labels is None:
                    raise ValueError("class_labels should be provided when num_class_embeds > 0")

                if self.config.class_embed_type == "timestep":
                    class_labels = self.time_proj(class_labels)

                    # Cast so `class_embedding` runs in the sample's dtype.
                    class_labels = class_labels.to(dtype=sample.dtype)

                class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

                if self.config.class_embeddings_concat:
                    emb = torch.cat([emb, class_emb], dim=-1)
                else:
                    emb = emb + class_emb

            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)
                emb = emb + aug_emb

            if self.time_embed_act is not None:
                emb = self.time_embed_act(emb)

            if self.encoder_hid_proj is not None:
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

            # 2. Decide whether this is a "key" step (recompute encoder
            # features) or a cached step.  `order` is the step index set by
            # the pipeline via setattr/register_time.
            order = self.order
            ipow = int(np.sqrt(9 + 8 * order))
            cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35]  # default: '50ls'
            if isinstance(mod, int):
                cond = order % mod == 0
            elif mod == "pro":
                # True when 9 + 8*order is a perfect square (progressive schedule).
                cond = ipow * ipow == (9 + 8 * order)
            elif mod == "50ls":
                cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35]
            elif mod == "50ls2":
                cond = order in [0, 10, 11, 12, 15, 20, 25, 30, 35, 45]
            elif mod == "50ls3":
                cond = order in [0, 20, 25, 30, 35, 45, 46, 47, 48, 49]
            elif mod == "50ls4":
                cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36, 45]
            elif mod == "100ls":
                cond = order > 85 or order < 10 or order % 5 == 0
            elif mod == "75ls":
                cond = order > 65 or order < 10 or order % 5 == 0
            elif mod == "s2":
                cond = order < 20 or order > 40 or order % 2 == 0

            if cond:
                # --- Key step: run the full encoder and cache its outputs ---
                # 3. Input projection.
                sample = self.conv_in(sample)

                # 4. Down blocks.
                down_block_res_samples = (sample,)
                for downsample_block in self.down_blocks:
                    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                        additional_residuals = {}
                        if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
                            additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
                        sample, res_samples = downsample_block(
                            hidden_states=sample,
                            temb=emb,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=attention_mask,
                            cross_attention_kwargs=cross_attention_kwargs,
                            **additional_residuals
                        )
                    else:
                        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                        if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
                            sample += down_intrablock_additional_residuals.pop(0)

                    down_block_res_samples += res_samples

                # BUG FIX: the original contained this ControlNet residual
                # addition twice in a row, adding each residual to the skip
                # features two times.  It is applied exactly once now.
                if down_block_additional_residuals is not None:
                    new_down_block_res_samples = ()

                    for down_block_res_sample, down_block_additional_residual in zip(
                        down_block_res_samples, down_block_additional_residuals
                    ):
                        down_block_res_sample = down_block_res_sample + down_block_additional_residual
                        new_down_block_res_samples += (down_block_res_sample,)

                    down_block_res_samples = new_down_block_res_samples

                # 5. Mid block.
                if self.mid_block is not None:
                    if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                        sample = self.mid_block(
                            sample,
                            emb,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=attention_mask,
                            cross_attention_kwargs=cross_attention_kwargs,
                        )
                    else:
                        sample = self.mid_block(sample, emb)

                    # T2I-Adapter-style residual for the mid block, only when
                    # shapes match.
                    if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape:
                        sample += down_intrablock_additional_residuals.pop(0)

                if mid_block_additional_residual is not None:
                    sample = sample + mid_block_additional_residual

                # Cache encoder outputs for subsequent non-key steps.
                setattr(self, 'skip_feature', deepcopy(down_block_res_samples))
                setattr(self, 'toup_feature', sample.detach().clone())

                # Multi-timestep pass: re-embed the full timestep list and
                # tile features/embeddings so the decoder covers every step.
                if isinstance(timestep, list):
                    timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device)
                    t_emb = self.time_proj(timesteps)

                    t_emb = t_emb.to(dtype=self.dtype)

                    emb = self.time_embedding(t_emb, timestep_cond)

                    down_block_res_samples = warpped_skip_feature(down_block_res_samples, step)
                    sample = warpped_feature(sample, step)
                    encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step)

            else:
                # --- Cached step: reuse encoder features from the last key
                # step and tile them for this pass's timesteps ---
                down_block_res_samples = self.skip_feature
                sample = self.toup_feature

                down_block_res_samples = warpped_skip_feature(down_block_res_samples, step)
                sample = warpped_feature(sample, step)
                encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step)

            # 6. Up blocks (always run).
            for i, upsample_block in enumerate(self.up_blocks):
                is_final_block = i == len(self.up_blocks) - 1

                res_samples = down_block_res_samples[-len(upsample_block.resnets):]
                down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

                # Forward an explicit upsample size when the input resolution
                # is not divisible by the overall up factor.
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        upsample_size=upsample_size,
                        attention_mask=attention_mask,
                    )
                else:
                    sample = upsample_block(
                        hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                    )

            # 7. Output projection.
            if self.conv_norm_out:
                sample = self.conv_norm_out(sample)
                sample = self.conv_act(sample)
            sample = self.conv_out(sample)

            if not return_dict:
                return (sample,)

            return UNet2DConditionOutput(sample=sample)
        return forward
    if model.__class__.__name__ == 'UNet2DConditionModel':
        model.forward = faster_forward(model)
|
def register_normal_forward(model):
    """Monkey-patch ``model.forward`` with a plain re-implementation of the
    stock diffusers ``UNet2DConditionModel.forward`` (no step-caching tricks).

    This is the baseline counterpart to ``faster_forward``: every down, mid,
    and up block is evaluated on each call. The patch is only applied when the
    model really is a ``UNet2DConditionModel``.
    """
    def normal_forward(self):
        # ``self`` is the UNet instance captured by closure; the returned
        # ``forward`` matches the signature of the stock diffusers forward.
        def forward(
            sample: torch.FloatTensor,
            timestep: Union[torch.Tensor, float, int],
            encoder_hidden_states: torch.Tensor,
            class_labels: Optional[torch.Tensor] = None,
            timestep_cond: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            mid_block_additional_residual: Optional[torch.Tensor] = None,
            return_dict: bool = True,
        ) -> Union[UNet2DConditionOutput, Tuple]:
            # By default, samples have to be a multiple of the overall
            # upsampling factor (2 ** number of upsamplers).
            default_overall_up_factor = 2**self.num_upsamplers

            # ``upsample_size`` is only forwarded to the up blocks when the
            # input spatial size is not a multiple of that factor.
            forward_upsample_size = False
            upsample_size = None

            if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
                print("Forward upsample size to force interpolation output size.")
                forward_upsample_size = True

            # Convert a 0/1 attention mask into an additive bias
            # (0 -> 0.0, masked -> -10000.0) and add a head dimension.
            if attention_mask is not None:
                attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # 0. center input if necessary ([0, 1] -> [-1, 1])
            if self.config.center_input_sample:
                sample = 2 * sample - 1.0

            # 1. time embedding
            timesteps = timestep
            if not torch.is_tensor(timesteps):
                # mps does not support float64 / int64, so pick a narrower
                # dtype on that backend.
                is_mps = sample.device.type == "mps"
                if isinstance(timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(sample.device)

            # Broadcast the (possibly scalar) timestep to the batch dimension.
            timesteps = timesteps.expand(sample.shape[0])

            t_emb = self.time_proj(timesteps)

            # ``time_proj`` may emit fp32/fp64; cast to the UNet's dtype
            # before running the embedding MLP.
            t_emb = t_emb.to(dtype=self.dtype)

            emb = self.time_embedding(t_emb, timestep_cond)

            if self.class_embedding is not None:
                if class_labels is None:
                    raise ValueError("class_labels should be provided when num_class_embeds > 0")

                if self.config.class_embed_type == "timestep":
                    class_labels = self.time_proj(class_labels)

                    # ``time_proj`` always returns fp32; match the sample
                    # dtype before embedding.
                    class_labels = class_labels.to(dtype=sample.dtype)

                class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

                if self.config.class_embeddings_concat:
                    emb = torch.cat([emb, class_emb], dim=-1)
                else:
                    emb = emb + class_emb

            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)
                emb = emb + aug_emb

            if self.time_embed_act is not None:
                emb = self.time_embed_act(emb)

            if self.encoder_hid_proj is not None:
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

            # 2. pre-process
            sample = self.conv_in(sample)

            # 3. down blocks — collect residuals for the skip connections.
            down_block_res_samples = (sample,)
            for i, downsample_block in enumerate(self.down_blocks):
                if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                    sample, res_samples = downsample_block(
                        hidden_states=sample,
                        temb=emb,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )
                else:
                    sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                down_block_res_samples += res_samples

            # Merge ControlNet down-block residuals into the skips, if given.
            if down_block_additional_residuals is not None:
                new_down_block_res_samples = ()

                for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals
                ):
                    down_block_res_sample = down_block_res_sample + down_block_additional_residual
                    new_down_block_res_samples += (down_block_res_sample,)

                down_block_res_samples = new_down_block_res_samples

            # 4. mid block
            if self.mid_block is not None:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

            if mid_block_additional_residual is not None:
                sample = sample + mid_block_additional_residual

            # 5. up blocks — consume the residual stack from the back.
            for i, upsample_block in enumerate(self.up_blocks):
                is_final_block = i == len(self.up_blocks) - 1

                res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
                down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

                # Pass the spatial size of the next skip connection so the
                # upsampler can interpolate to a non-multiple resolution.
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        upsample_size=upsample_size,
                        attention_mask=attention_mask,
                    )
                else:
                    sample = upsample_block(
                        hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                    )

            # 6. post-process
            if self.conv_norm_out:
                sample = self.conv_norm_out(sample)
                sample = self.conv_act(sample)
            sample = self.conv_out(sample)

            if not return_dict:
                return (sample,)

            return UNet2DConditionOutput(sample=sample)
        return forward
    if model.__class__.__name__ == 'UNet2DConditionModel':
        model.forward = normal_forward(model)
| |
|
def register_time(unet, t):
    """Record the current step index on the UNet as its ``order`` attribute."""
    unet.order = t
| |
|
def register_controlnet_pipeline(pipe):
    """Attach a ``call`` method to a StableDiffusion ControlNet pipeline that
    runs an accelerated denoising schedule.

    The ControlNet is evaluated only at a fixed set of key steps
    (``[0,1,2,3,5,10,15,25,35]``); between key steps several timesteps are
    batched through a single UNet call (the patched UNet forward accepts a
    list of timesteps) and their residuals are None.
    """
    def new_call(self):
        @torch.no_grad()
        def call(
            prompt: Union[str, List[str]] = None,
            image: Union[
                torch.FloatTensor,
                PIL.Image.Image,
                np.ndarray,
                List[torch.FloatTensor],
                List[PIL.Image.Image],
                List[np.ndarray],
            ] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
            guess_mode: bool = False,
        ):
            # 1. Check inputs; raises if anything is inconsistent.
            self.check_inputs(
                prompt,
                image,
                callback_steps,
                negative_prompt,
                prompt_embeds,
                negative_prompt_embeds,
                controlnet_conditioning_scale,
            )

            # 2. Define call parameters.
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device
            # Classifier-free guidance is enabled whenever guidance_scale > 1.
            do_classifier_free_guidance = guidance_scale > 1.0

            # Unwrap a torch.compile'd ControlNet to reach its real module.
            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

            global_pool_conditions = (
                controlnet.config.global_pool_conditions
                if isinstance(controlnet, ControlNetModel)
                else controlnet.nets[0].config.global_pool_conditions
            )
            guess_mode = guess_mode or global_pool_conditions

            # 3. Encode input prompt.
            text_encoder_lora_scale = (
                cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
            )
            prompt_embeds = self._encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
            )

            # 4. Prepare the conditioning image(s) for the ControlNet(s).
            if isinstance(controlnet, ControlNetModel):
                image = self.prepare_image(
                    image=image,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )
                height, width = image.shape[-2:]
            elif isinstance(controlnet, MultiControlNetModel):
                images = []

                for image_ in image:
                    image_ = self.prepare_image(
                        image=image_,
                        width=width,
                        height=height,
                        batch_size=batch_size * num_images_per_prompt,
                        num_images_per_prompt=num_images_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )

                    images.append(image_)

                image = images
                height, width = image[0].shape[-2:]
            else:
                assert False

            # 5. Prepare timesteps.
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

            # 6. Prepare latent variables.
            num_channels_latents = self.unet.config.in_channels
            latents = self.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )
            # Keep the initial noise: it is re-injected during late steps below.
            self.init_latent = latents.detach().clone()

            # 7. Extra kwargs for scheduler.step (eta / generator).
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 8. Denoising loop with the key-step schedule. ``curr_span`` is
            # how many timesteps are batched into the next UNet call; spans
            # are the gaps between consecutive entries of ``keytime``.
            all_steps = len(self.scheduler.timesteps)
            curr_span = 1
            curr_step = 0

            idx = 1
            keytime = [0,1,2,3,5,10,15,25,35]
            keytime.append(all_steps)
            while curr_step<all_steps:
                register_time(self.unet, curr_step)

                if curr_span>0:
                    time_ls = []
                    for i in range(curr_step, curr_step+curr_span):
                        if i<all_steps:
                            time_ls.append(self.scheduler.timesteps[i])
                        else:
                            break

                # Expand latents for classifier-free guidance; scale with the
                # first timestep of the span.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_ls[0])

                if curr_step in [0,1,2,3,5,10,15,25,35]:
                    # Key step: run the ControlNet to get fresh residuals.
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        control_model_input,
                        time_ls[0],
                        encoder_hidden_states=controlnet_prompt_embeds,
                        controlnet_cond=image,
                        conditioning_scale=controlnet_conditioning_scale,
                        guess_mode=guess_mode,
                        return_dict=False,
                    )

                else:
                    # Non-key step: skip the ControlNet entirely.
                    down_block_res_samples = None
                    mid_block_res_sample = None

                # One UNet call covers the whole span; the patched forward
                # accepts ``time_ls`` (a list of timesteps).
                noise_pred = self.unet(
                    latent_model_input,
                    time_ls,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # Perform guidance.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Unroll the batched prediction into sequential scheduler
                # steps; the batch dim stacks per-timestep predictions.
                if isinstance(time_ls, list):
                    step_span = len(time_ls)
                    bs = noise_pred.shape[0]
                    bs_perstep = bs//step_span

                    denoised_latent = latents
                    for i, timestep in enumerate(time_ls):
                        # Late steps (t < 500 out of 1000): nudge towards the
                        # initial noise — small fixed factor 0.003.
                        if timestep/1000 < 0.5:
                            denoised_latent = denoised_latent + 0.003*self.init_latent
                        curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep]
                        denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0]

                    latents = denoised_latent

                curr_step += curr_span
                idx += 1
                if curr_step<all_steps:
                    # Next span is the gap to the next key step.
                    curr_span = keytime[idx] - keytime[idx-1]

            # Manually offload UNet and ControlNet when sequential offload is
            # active, to free GPU memory before VAE decode.
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.unet.to("cpu")
                self.controlnet.to("cpu")
                torch.cuda.empty_cache()

            if not output_type == "latent":
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            else:
                image = latents
                has_nsfw_concept = None

            if has_nsfw_concept is None:
                do_denormalize = [True] * image.shape[0]
            else:
                do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

            # Offload the last model to CPU.
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.final_offload_hook.offload()

            if not return_dict:
                return (image, has_nsfw_concept)

            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
        return call
    pipe.call = new_call(pipe)
| |
|
@torch.no_grad()
def multistep_pre(self, noise_pred, t, x):
    """Apply several scheduler steps in sequence from one batched prediction.

    ``noise_pred`` stacks the per-timestep noise predictions along the batch
    dimension; slice ``i`` belongs to timestep ``t[i]``.  Each slice is fed to
    ``self.scheduler.step`` in order, chaining the denoised latent through.
    """
    per_step = noise_pred.shape[0] // len(t)
    latent = x
    for idx, timestep in enumerate(t):
        chunk = noise_pred[idx * per_step:(idx + 1) * per_step]
        latent = self.scheduler.step(chunk, timestep, latent)['prev_sample']
    return latent
| |
|
def register_t2v(model):
    """Patch ``model.backward_loop`` of a text-to-video pipeline with a
    denoising loop that can batch several timesteps through one UNet call.

    Fixes vs. the original: the dead ``import time`` inside the loop body is
    removed, and the key-step list is defined once instead of being written
    out twice.  Behavior is unchanged.
    """
    def new_back(self):
        def backward_loop(
            latents,
            timesteps,
            prompt_embeds,
            guidance_scale,
            callback,
            callback_steps,
            num_warmup_steps,
            extra_step_kwargs,
            cross_attention_kwargs=None,):
            """Denoise ``latents`` over ``timesteps`` and return the result.

            Short loops (< 10 scheduler steps) use the standard one-step-at-a-
            time diffusers loop; longer loops group timesteps between key steps
            and evaluate each group in a single batched UNet forward.
            """
            do_classifier_free_guidance = guidance_scale > 1.0
            num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
            if num_steps < 10:
                # Standard diffusers loop: one scheduler step per iteration.
                with self.progress_bar(total=num_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        setattr(self.unet, 'order', i)
                        # Expand the latents if we are doing classifier-free guidance.
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                        # Predict the noise residual.
                        noise_pred = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                        ).sample

                        # Perform guidance.
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # Compute the previous noisy sample x_t -> x_t-1.
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                        # Update the progress bar and invoke the callback.
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                            progress_bar.update()
                            if callback is not None and i % callback_steps == 0:
                                step_idx = i // getattr(self.scheduler, "order", 1)
                                callback(step_idx, t, latents)

            else:
                # Faster-diffusion loop: batch the timesteps between
                # consecutive key steps into a single UNet call.
                key_steps = [0, 1, 2, 3, 5, 10, 15, 25, 35]
                all_timesteps = len(timesteps)
                curr_step = 0

                while curr_step < all_timesteps:
                    register_time(self.unet, curr_step)

                    # Collect this group's timesteps: consume until the next
                    # key step (or the end of the schedule).
                    time_ls = [timesteps[curr_step]]
                    curr_step += 1
                    while (curr_step not in key_steps) and (curr_step < all_timesteps):
                        time_ls.append(timesteps[curr_step])
                        curr_step += 1

                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    # NOTE(review): unlike the short-loop branch, the input is
                    # not passed through scheduler.scale_model_input here —
                    # confirm the scheduler in use does not require scaling.

                    noise_pred = self.unet(
                        latent_model_input,
                        time_ls,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                    # Perform guidance.
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # Unroll the batched prediction into sequential scheduler steps.
                    latents = multistep_pre(self, noise_pred, time_ls, latents)

            return latents.clone().detach()
        return backward_loop
    model.backward_loop = new_back(model)
| | |
| |
|