import numpy as np
import cv2
import torch
import random
from PIL import Image, ImageDraw, ImageFont
import copy
from typing import Optional, Union, Tuple, List, Callable, Dict, Any
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionOutput, logger
from copy import deepcopy
import json
import inspect
import os
import warnings
import PIL.Image
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from tqdm import tqdm
from controlnet_aux import HEDdetector, OpenposeDetector
import time
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def get_promptls(prompt_path):
    with open(prompt_path) as f:
        prompt_ls = json.load(f)
    prompt_ls = [prompt['caption'].replace('/', '_') for prompt in prompt_ls]
    return prompt_ls
def load_512(image_path, left=0, right=0, top=0, bottom=0):
    # print(image_path)
    if type(image_path) is str:
        image = np.array(Image.open(image_path))
        if image.ndim == 3 and image.shape[2] > 3:
            # drop the alpha channel of RGBA inputs
            image = image[:, :, :3]
        elif image.ndim == 2:
            # promote grayscale inputs to 3 channels so the cropping/resizing below works
            image = np.stack([image] * 3, axis=-1).astype('uint8')
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image
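# --- Usage sketch for `load_512` (illustrative; the path below is a placeholder) ---
def _example_load_512():
    img = load_512('example_inputs/dog.png')
    print(img.shape)   # (512, 512, 3) for an RGB input
    return img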
def get_canny(image_path):
    image = load_512(image_path)
    image = np.array(image)
    # get canny image
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image


def get_scribble(image_path, hed):
    image = load_512(image_path)
    image = hed(image, scribble=True)
    return image
def get_cocoimages(prompt_path):
    """canny condition
    """
    data_ls = []
    with open(prompt_path) as f:
        prompt_ls = json.load(f)
    img_path = 'COCO2017-val/val2017'
    for prompt in tqdm(prompt_ls):
        caption = prompt['caption'].replace('/', '_')
        image_id = str(prompt['image_id'])
        image_id = (12 - len(image_id)) * '0' + image_id + '.jpg'
        image_path = os.path.join(img_path, image_id)
        try:
            image = get_canny(image_path)
        except Exception:
            continue
        curr_data = {'image': image, 'prompt': caption}
        data_ls.append(curr_data)
    return data_ls


def get_cocoimages2(prompt_path):
    """scribble condition
    """
    data_ls = []
    with open(prompt_path) as f:
        prompt_ls = json.load(f)
    img_path = 'COCO2017-val/val2017'
    hed = HEDdetector.from_pretrained('ControlNet/detector_weights/annotator', filename='network-bsds500.pth')
    for prompt in tqdm(prompt_ls):
        caption = prompt['caption'].replace('/', '_')
        image_id = str(prompt['image_id'])
        image_id = (12 - len(image_id)) * '0' + image_id + '.jpg'
        image_path = os.path.join(img_path, image_id)
        try:
            image = get_scribble(image_path, hed)
        except Exception:
            continue
        curr_data = {'image': image, 'prompt': caption}
        data_ls.append(curr_data)
    return data_ls
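# --- Note on the prompt file format (an assumption inferred from the fields used above) ---
# `get_promptls`, `get_cocoimages` and `get_cocoimages2` expect a COCO-style captions list:
# a JSON array of objects carrying at least 'caption' and 'image_id', for example
#   [{"image_id": 139, "caption": "A woman stands in the dining area at the table."}, ...]
# The zero-padded 12-digit name built above ('000000000139.jpg') matches the COCO2017
# val2017 image naming convention.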
def warpped_feature(sample, step):
    """
    sample: batch_size*dim*h*w, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size
    step: timestep span
    """
    bs, dim, h, w = sample.shape
    uncond_fea, cond_fea = sample.chunk(2)
    uncond_fea = uncond_fea.repeat(step, 1, 1, 1)  # (step * bs//2) * dim * h * w
    cond_fea = cond_fea.repeat(step, 1, 1, 1)  # (step * bs//2) * dim * h * w
    return torch.cat([uncond_fea, cond_fea])


def warpped_skip_feature(block_samples, step):
    down_block_res_samples = []
    for sample in block_samples:
        sample_expand = warpped_feature(sample, step)
        down_block_res_samples.append(sample_expand)
    return tuple(down_block_res_samples)


def warpped_text_emb(text_emb, step):
    """
    text_emb: batch_size*77*768, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size
    step: timestep span
    """
    bs, token_len, dim = text_emb.shape
    uncond_fea, cond_fea = text_emb.chunk(2)
    uncond_fea = uncond_fea.repeat(step, 1, 1)  # (step * bs//2) * 77 * 768
    cond_fea = cond_fea.repeat(step, 1, 1)  # (step * bs//2) * 77 * 768
    return torch.cat([uncond_fea, cond_fea])  # (step*bs) * 77 * 768


def warpped_timestep(timesteps, bs):
    """
    timesteps: list, such as [981, 961, 941]
    """
    semi_bs = bs // 2
    ts = []
    for timestep in timesteps:
        timestep = timestep[None]
        texp = timestep.expand(semi_bs)
        ts.append(texp)
    timesteps = torch.cat(ts)
    return timesteps.repeat(2, 1).reshape(-1)
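# --- Shape sketch for the wrapping helpers above (illustrative, dummy tensors) ---
# The `warpped_*` helpers replicate the uncond/cond halves of a classifier-free-guidance
# batch so that one UNet decoder pass can serve several consecutive timesteps at once.
def _example_warpped_shapes():
    step = 3                                    # number of merged timesteps
    feat = torch.randn(2, 320, 64, 64)          # CFG batch: [uncond, cond]
    text = torch.randn(2, 77, 768)              # text embeddings for the same batch
    t_ls = [torch.tensor(981), torch.tensor(961), torch.tensor(941)]
    print(warpped_feature(feat, step).shape)             # torch.Size([6, 320, 64, 64])
    print(warpped_text_emb(text, step).shape)            # torch.Size([6, 77, 768])
    print(warpped_timestep(t_ls, feat.shape[0]).shape)   # torch.Size([6])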
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
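# --- Usage sketch for `rescale_noise_cfg` (illustrative, dummy tensors) ---
# After classifier-free guidance the combined prediction can be rescaled toward the std of
# the text-conditioned branch to counter over-exposure; guidance_rescale=0 is a no-op and
# guidance_rescale=1 rescales fully.
def _example_rescale_noise_cfg():
    noise_pred_uncond = torch.randn(1, 4, 64, 64)
    noise_pred_text = torch.randn(1, 4, 64, 64)
    noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
    return rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)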
| def register_normal_pipeline(pipe): | |
| def new_call(self): | |
| def call( | |
| prompt: Union[str, List[str]] = None, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_images_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| guidance_rescale: float = 0.0, | |
| clip_skip: Optional[int] = None, | |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, | |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], | |
| **kwargs, | |
| ): | |
| callback = kwargs.pop("callback", None) | |
| callback_steps = kwargs.pop("callback_steps", None) | |
| # 0. Default height and width to unet | |
| height = height or self.unet.config.sample_size * self.vae_scale_factor | |
| width = width or self.unet.config.sample_size * self.vae_scale_factor | |
| # to deal with lora scaling and other possible forward hooks | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| height, | |
| width, | |
| callback_steps, | |
| negative_prompt, | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| callback_on_step_end_tensor_inputs, | |
| ) | |
| self._guidance_scale = guidance_scale | |
| self._guidance_rescale = guidance_rescale | |
| self._clip_skip = clip_skip | |
| self._cross_attention_kwargs = cross_attention_kwargs | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| # 3. Encode input prompt | |
| lora_scale = ( | |
| self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None | |
| ) | |
| prompt_embeds, negative_prompt_embeds = self.encode_prompt( | |
| prompt, | |
| device, | |
| num_images_per_prompt, | |
| self.do_classifier_free_guidance, | |
| negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| lora_scale=lora_scale, | |
| clip_skip=self.clip_skip, | |
| ) | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| if self.do_classifier_free_guidance: | |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) | |
| # 4. Prepare timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = self.scheduler.timesteps | |
| # 5. Prepare latent variables | |
| num_channels_latents = self.unet.config.in_channels | |
| latents = self.prepare_latents( | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| latents, | |
| ) | |
| # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| # 6.5 Optionally get Guidance Scale Embedding | |
| timestep_cond = None | |
| if self.unet.config.time_cond_proj_dim is not None: | |
| guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) | |
| timestep_cond = self.get_guidance_scale_embedding( | |
| guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim | |
| ).to(device=device, dtype=latents.dtype) | |
| # 7. Denoising loop | |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
| self._num_timesteps = len(timesteps) | |
| init_latents = latents.detach().clone() | |
| with self.progress_bar(total=num_inference_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| if t/1000 < 0.5: | |
| latents = latents + 0.003*init_latents | |
| setattr(self.unet, 'order', i) | |
| # expand the latents if we are doing classifier free guidance | |
| latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| t, | |
| encoder_hidden_states=prompt_embeds, | |
| timestep_cond=timestep_cond, | |
| cross_attention_kwargs=self.cross_attention_kwargs, | |
| return_dict=False, | |
| )[0] | |
| # perform guidance | |
| if self.do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: | |
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | |
| noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | |
| if callback_on_step_end is not None: | |
| callback_kwargs = {} | |
| for k in callback_on_step_end_tensor_inputs: | |
| callback_kwargs[k] = locals()[k] | |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) | |
| latents = callback_outputs.pop("latents", latents) | |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) | |
| negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
| progress_bar.update() | |
| if callback is not None and i % callback_steps == 0: | |
| step_idx = i // getattr(self.scheduler, "order", 1) | |
| callback(step_idx, t, latents) | |
| if not output_type == "latent": | |
| image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ | |
| 0 | |
| ] | |
| image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | |
| else: | |
| image = latents | |
| has_nsfw_concept = None | |
| if has_nsfw_concept is None: | |
| do_denormalize = [True] * image.shape[0] | |
| else: | |
| do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | |
| image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | |
| # Offload all models | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (image, has_nsfw_concept) | |
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | |
| return call | |
| pipe.call = new_call(pipe) | |
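# --- Usage sketch for `register_normal_pipeline` (illustrative; not from the original file) ---
# This patches the step-by-step baseline sampling loop onto `pipe.call`. It assumes a
# recent diffusers StableDiffusionPipeline (any SD 1.x checkpoint); the checkpoint id and
# prompt below are only examples.
def _example_normal_pipeline():
    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to('cuda')
    register_normal_pipeline(pipe)   # the patched loop becomes available as `pipe.call`
    image = pipe.call(prompt='a photo of an astronaut riding a horse on mars',
                      num_inference_steps=50).images[0]
    return image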
| def register_parallel_pipeline(pipe): | |
| def new_call(self): | |
| def call( | |
| prompt: Union[str, List[str]] = None, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_images_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| guidance_rescale: float = 0.0, | |
| clip_skip: Optional[int] = None, | |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, | |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], | |
| **kwargs, | |
| ): | |
| callback = kwargs.pop("callback", None) | |
| callback_steps = kwargs.pop("callback_steps", None) | |
| # 0. Default height and width to unet | |
| height = height or self.unet.config.sample_size * self.vae_scale_factor | |
| width = width or self.unet.config.sample_size * self.vae_scale_factor | |
| # to deal with lora scaling and other possible forward hooks | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| height, | |
| width, | |
| callback_steps, | |
| negative_prompt, | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| callback_on_step_end_tensor_inputs, | |
| ) | |
| self._guidance_scale = guidance_scale | |
| self._guidance_rescale = guidance_rescale | |
| self._clip_skip = clip_skip | |
| self._cross_attention_kwargs = cross_attention_kwargs | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| # 3. Encode input prompt | |
| lora_scale = ( | |
| self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None | |
| ) | |
| prompt_embeds, negative_prompt_embeds = self.encode_prompt( | |
| prompt, | |
| device, | |
| num_images_per_prompt, | |
| self.do_classifier_free_guidance, | |
| negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| lora_scale=lora_scale, | |
| clip_skip=self.clip_skip, | |
| ) | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| if self.do_classifier_free_guidance: | |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) | |
| # 4. Prepare timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = self.scheduler.timesteps | |
| # 5. Prepare latent variables | |
| num_channels_latents = self.unet.config.in_channels | |
| latents = self.prepare_latents( | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| latents, | |
| ) | |
| # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| # 6.5 Optionally get Guidance Scale Embedding | |
| timestep_cond = None | |
| if self.unet.config.time_cond_proj_dim is not None: | |
| guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) | |
| timestep_cond = self.get_guidance_scale_embedding( | |
| guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim | |
| ).to(device=device, dtype=latents.dtype) | |
| # 7. Denoising loop | |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
| self._num_timesteps = len(timesteps) | |
| init_latents = latents.detach().clone() | |
| #------------------------------------------------------- | |
| all_steps = len(self.scheduler.timesteps) | |
| curr_span = 1 | |
| curr_step = 0 | |
| # st = time.time() | |
| idx = 1 | |
| keytime = [0,1,2,3,5,10,15,25,35] | |
| keytime.append(all_steps) | |
| while curr_step<all_steps: | |
| refister_time(self.unet, curr_step) | |
| merge_span = curr_span | |
| if merge_span>0: | |
| time_ls = [] | |
| for i in range(curr_step, curr_step+merge_span): | |
| if i<all_steps: | |
| time_ls.append(self.scheduler.timesteps[i]) | |
| else: | |
| break | |
| ##-------------------------------- | |
| latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| time_ls, | |
| encoder_hidden_states=prompt_embeds, | |
| timestep_cond=timestep_cond, | |
| cross_attention_kwargs=self.cross_attention_kwargs, | |
| return_dict=False, | |
| )[0] | |
| # perform guidance | |
| if self.do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: | |
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | |
| noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| step_span = len(time_ls) | |
| bs = noise_pred.shape[0] | |
| bs_perstep = bs//step_span | |
| denoised_latent = latents | |
| for i, timestep in enumerate(time_ls): | |
| if timestep/1000 < 0.5: | |
| denoised_latent = denoised_latent + 0.003*init_latents | |
| curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep] | |
| denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0] | |
| latents = denoised_latent | |
| ##---------------------------------------- | |
| curr_step += curr_span | |
| idx += 1 | |
| if curr_step<all_steps: | |
| curr_span = keytime[idx] - keytime[idx-1] | |
| if not output_type == "latent": | |
| image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ | |
| 0 | |
| ] | |
| image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | |
| else: | |
| image = latents | |
| has_nsfw_concept = None | |
| if has_nsfw_concept is None: | |
| do_denormalize = [True] * image.shape[0] | |
| else: | |
| do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | |
| image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | |
| # Offload all models | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (image, has_nsfw_concept) | |
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | |
| return call | |
| pipe.call = new_call(pipe) | |
| def register_faster_forward(model, mod = '50ls'): | |
| def faster_forward(self): | |
| def forward( | |
| sample: torch.FloatTensor, | |
| timestep: Union[torch.Tensor, float, int], | |
| encoder_hidden_states: torch.Tensor, | |
| class_labels: Optional[torch.Tensor] = None, | |
| timestep_cond: Optional[torch.Tensor] = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| ) -> Union[UNet2DConditionOutput, Tuple]: | |
| r""" | |
| Args: | |
| sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor | |
| timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps | |
| encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. | |
| cross_attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
| `self.processor` in | |
| [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | |
| Returns: | |
| [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: | |
| [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When | |
| returning a tuple, the first element is the sample tensor. | |
| """ | |
| # By default samples have to be AT least a multiple of the overall upsampling factor. | |
| # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
| # However, the upsampling interpolation output size can be forced to fit any upsampling size | |
| # on the fly if necessary. | |
| default_overall_up_factor = 2**self.num_upsamplers | |
| # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
| forward_upsample_size = False | |
| upsample_size = None | |
| if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): | |
| logger.info("Forward upsample size to force interpolation output size.") | |
| forward_upsample_size = True | |
| # prepare attention_mask | |
| if attention_mask is not None: | |
| attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
| attention_mask = attention_mask.unsqueeze(1) | |
| # 0. center input if necessary | |
| if self.config.center_input_sample: | |
| sample = 2 * sample - 1.0 | |
| # 1. time | |
| if isinstance(timestep, list): | |
| timesteps = timestep[0] | |
| step = len(timestep) | |
| else: | |
| timesteps = timestep | |
| step = 1 | |
| if not torch.is_tensor(timesteps) and (not isinstance(timesteps,list)): | |
| # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
| # This would be a good case for the `match` statement (Python 3.10+) | |
| is_mps = sample.device.type == "mps" | |
| if isinstance(timestep, float): | |
| dtype = torch.float32 if is_mps else torch.float64 | |
| else: | |
| dtype = torch.int32 if is_mps else torch.int64 | |
| timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
| elif (not isinstance(timesteps,list)) and len(timesteps.shape) == 0: | |
| timesteps = timesteps[None].to(sample.device) | |
| if (not isinstance(timesteps,list)) and len(timesteps.shape) == 1: | |
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| timesteps = timesteps.expand(sample.shape[0]) | |
| elif isinstance(timesteps, list): | |
| #timesteps list, such as [981,961,941] | |
| timesteps = warpped_timestep(timesteps, sample.shape[0]).to(sample.device) | |
| t_emb = self.time_proj(timesteps) | |
| # `Timesteps` does not contain any weights and will always return f32 tensors | |
| # but time_embedding might actually be running in fp16. so we need to cast here. | |
| # there might be better ways to encapsulate this. | |
| t_emb = t_emb.to(dtype=self.dtype) | |
| emb = self.time_embedding(t_emb, timestep_cond) | |
| if self.class_embedding is not None: | |
| if class_labels is None: | |
| raise ValueError("class_labels should be provided when num_class_embeds > 0") | |
| if self.config.class_embed_type == "timestep": | |
| class_labels = self.time_proj(class_labels) | |
| # `Timesteps` does not contain any weights and will always return f32 tensors | |
| # there might be better ways to encapsulate this. | |
| class_labels = class_labels.to(dtype=sample.dtype) | |
| class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) | |
| if self.config.class_embeddings_concat: | |
| emb = torch.cat([emb, class_emb], dim=-1) | |
| else: | |
| emb = emb + class_emb | |
| if self.config.addition_embed_type == "text": | |
| aug_emb = self.add_embedding(encoder_hidden_states) | |
| emb = emb + aug_emb | |
| if self.time_embed_act is not None: | |
| emb = self.time_embed_act(emb) | |
| if self.encoder_hid_proj is not None: | |
| encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) | |
| #=============== | |
| order = self.order #timestep, start by 0 | |
| #=============== | |
| ipow = int(np.sqrt(9 + 8*order)) | |
| cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] | |
| if isinstance(mod, int): | |
| cond = order % mod == 0 | |
| elif mod == "pro": | |
| cond = ipow * ipow == (9 + 8 * order) | |
| elif mod == "50ls": | |
| cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] | |
| elif mod == "50ls2": | |
| cond = order in [0, 10, 11, 12, 15, 20, 25, 30,35,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] | |
| elif mod == "50ls3": | |
| cond = order in [0, 20, 25, 30,35,45,46,47,48,49] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] | |
| elif mod == "50ls4": | |
| cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40] | |
| elif mod == "100ls": | |
| cond = order > 85 or order < 10 or order % 5 == 0 | |
| elif mod == "75ls": | |
| cond = order > 65 or order < 10 or order % 5 == 0 | |
| elif mod == "s2": | |
| cond = order < 20 or order > 40 or order % 2 == 0 | |
| if cond: | |
| print(order) | |
| # 2. pre-process | |
| sample = self.conv_in(sample) | |
| # 3. down | |
| down_block_res_samples = (sample,) | |
| for downsample_block in self.down_blocks: | |
| if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
| sample, res_samples = downsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ) | |
| else: | |
| sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
| down_block_res_samples += res_samples | |
| if down_block_additional_residuals is not None: | |
| new_down_block_res_samples = () | |
| for down_block_res_sample, down_block_additional_residual in zip( | |
| down_block_res_samples, down_block_additional_residuals | |
| ): | |
| down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
| new_down_block_res_samples += (down_block_res_sample,) | |
| down_block_res_samples = new_down_block_res_samples | |
| # 4. mid | |
| if self.mid_block is not None: | |
| sample = self.mid_block( | |
| sample, | |
| emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ) | |
| if mid_block_additional_residual is not None: | |
| sample = sample + mid_block_additional_residual | |
| #----------------------save feature------------------------- | |
| # setattr(self, 'skip_feature', (tmp_sample.clone() for tmp_sample in down_block_res_samples)) | |
| setattr(self, 'skip_feature', deepcopy(down_block_res_samples)) | |
| setattr(self, 'toup_feature', sample.detach().clone()) | |
| #-----------------------save feature------------------------ | |
| #-------------------expand feature for parallel--------------- | |
| if isinstance(timestep, list): | |
| #timesteps list, such as [981,961,941] | |
| timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device) | |
| t_emb = self.time_proj(timesteps) | |
| # `Timesteps` does not contain any weights and will always return f32 tensors | |
| # but time_embedding might actually be running in fp16. so we need to cast here. | |
| # there might be better ways to encapsulate this. | |
| t_emb = t_emb.to(dtype=self.dtype) | |
| emb = self.time_embedding(t_emb, timestep_cond) | |
| # print(emb.shape) | |
| # print(step, sample.shape) | |
| down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) | |
| sample = warpped_feature(sample, step) | |
| # print(step, sample.shape) | |
| encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) | |
| # print(emb.shape) | |
| #-------------------expand feature for parallel--------------- | |
| else: | |
| down_block_res_samples = self.skip_feature | |
| sample = self.toup_feature | |
| #-------------------expand feature for parallel--------------- | |
| down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) | |
| sample = warpped_feature(sample, step) | |
| encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) | |
| #-------------------expand feature for parallel--------------- | |
| # 5. up | |
| for i, upsample_block in enumerate(self.up_blocks): | |
| is_final_block = i == len(self.up_blocks) - 1 | |
| res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
| down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
| # if we have not reached the final block and need to forward the | |
| # upsample size, we do it here | |
| if not is_final_block and forward_upsample_size: | |
| upsample_size = down_block_res_samples[-1].shape[2:] | |
| if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
| sample = upsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| res_hidden_states_tuple=res_samples, | |
| encoder_hidden_states=encoder_hidden_states, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| upsample_size=upsample_size, | |
| attention_mask=attention_mask, | |
| ) | |
| else: | |
| sample = upsample_block( | |
| hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size | |
| ) | |
| # 6. post-process | |
| if self.conv_norm_out: | |
| sample = self.conv_norm_out(sample) | |
| sample = self.conv_act(sample) | |
| sample = self.conv_out(sample) | |
| if not return_dict: | |
| return (sample,) | |
| return UNet2DConditionOutput(sample=sample) | |
| return forward | |
| if model.__class__.__name__ == 'UNet2DConditionModel': | |
| model.forward = faster_forward(model) | |
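# --- Usage sketch: parallel decoding + encoder propagation (illustrative; not from the original file) ---
# `register_parallel_pipeline` batches several consecutive timesteps through one UNet call,
# and `register_faster_forward` recomputes the UNet encoder (down/mid blocks) only at the
# key timesteps of the chosen schedule, reusing cached skip/mid features otherwise.
# Assumes a recent diffusers StableDiffusionPipeline; checkpoint id and prompt are examples.
def _example_faster_diffusion():
    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to('cuda')
    register_parallel_pipeline(pipe)                 # accelerated loop on `pipe.call`
    register_faster_forward(pipe.unet, mod='50ls')   # key-timestep schedule for 50-step sampling
    image = pipe.call(prompt='a cat wearing sunglasses',
                      num_inference_steps=50).images[0]
    return image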
| def register_normal_forward(model): | |
| def normal_forward(self): | |
| def forward( | |
| sample: torch.FloatTensor, | |
| timestep: Union[torch.Tensor, float, int], | |
| encoder_hidden_states: torch.Tensor, | |
| class_labels: Optional[torch.Tensor] = None, | |
| timestep_cond: Optional[torch.Tensor] = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| ) -> Union[UNet2DConditionOutput, Tuple]: | |
| # By default samples have to be AT least a multiple of the overall upsampling factor. | |
| # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
| # However, the upsampling interpolation output size can be forced to fit any upsampling size | |
| # on the fly if necessary. | |
| default_overall_up_factor = 2**self.num_upsamplers | |
| # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
| forward_upsample_size = False | |
| upsample_size = None | |
| #--------------------- | |
| # import os | |
| # os.makedirs(f'{timestep.item()}_step', exist_ok=True) | |
| #--------------------- | |
| if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): | |
| logger.info("Forward upsample size to force interpolation output size.") | |
| forward_upsample_size = True | |
| # prepare attention_mask | |
| if attention_mask is not None: | |
| attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
| attention_mask = attention_mask.unsqueeze(1) | |
| # 0. center input if necessary | |
| if self.config.center_input_sample: | |
| sample = 2 * sample - 1.0 | |
| # 1. time | |
| timesteps = timestep | |
| if not torch.is_tensor(timesteps): | |
| # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
| # This would be a good case for the `match` statement (Python 3.10+) | |
| is_mps = sample.device.type == "mps" | |
| if isinstance(timestep, float): | |
| dtype = torch.float32 if is_mps else torch.float64 | |
| else: | |
| dtype = torch.int32 if is_mps else torch.int64 | |
| timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
| elif len(timesteps.shape) == 0: | |
| timesteps = timesteps[None].to(sample.device) | |
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| timesteps = timesteps.expand(sample.shape[0]) | |
| t_emb = self.time_proj(timesteps) | |
| # `Timesteps` does not contain any weights and will always return f32 tensors | |
| # but time_embedding might actually be running in fp16. so we need to cast here. | |
| # there might be better ways to encapsulate this. | |
| t_emb = t_emb.to(dtype=self.dtype) | |
| emb = self.time_embedding(t_emb, timestep_cond) | |
| if self.class_embedding is not None: | |
| if class_labels is None: | |
| raise ValueError("class_labels should be provided when num_class_embeds > 0") | |
| if self.config.class_embed_type == "timestep": | |
| class_labels = self.time_proj(class_labels) | |
| # `Timesteps` does not contain any weights and will always return f32 tensors | |
| # there might be better ways to encapsulate this. | |
| class_labels = class_labels.to(dtype=sample.dtype) | |
| class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) | |
| if self.config.class_embeddings_concat: | |
| emb = torch.cat([emb, class_emb], dim=-1) | |
| else: | |
| emb = emb + class_emb | |
| if self.config.addition_embed_type == "text": | |
| aug_emb = self.add_embedding(encoder_hidden_states) | |
| emb = emb + aug_emb | |
| if self.time_embed_act is not None: | |
| emb = self.time_embed_act(emb) | |
| if self.encoder_hid_proj is not None: | |
| encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) | |
| # 2. pre-process | |
| sample = self.conv_in(sample) | |
| # 3. down | |
| down_block_res_samples = (sample,) | |
| for i, downsample_block in enumerate(self.down_blocks): | |
| if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
| sample, res_samples = downsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ) | |
| else: | |
| sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
| #--------------------------------- | |
| # torch.save(sample, f'{timestep.item()}_step/down_{i}.pt') | |
| #---------------------------------- | |
| down_block_res_samples += res_samples | |
| if down_block_additional_residuals is not None: | |
| new_down_block_res_samples = () | |
| for down_block_res_sample, down_block_additional_residual in zip( | |
| down_block_res_samples, down_block_additional_residuals | |
| ): | |
| down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
| new_down_block_res_samples += (down_block_res_sample,) | |
| down_block_res_samples = new_down_block_res_samples | |
| # 4. mid | |
| if self.mid_block is not None: | |
| sample = self.mid_block( | |
| sample, | |
| emb, | |
| encoder_hidden_states=encoder_hidden_states, | |
| attention_mask=attention_mask, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ) | |
| # torch.save(sample, f'{timestep.item()}_step/mid.pt') | |
| if mid_block_additional_residual is not None: | |
| sample = sample + mid_block_additional_residual | |
| # 5. up | |
| for i, upsample_block in enumerate(self.up_blocks): | |
| is_final_block = i == len(self.up_blocks) - 1 | |
| res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
| down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
| # if we have not reached the final block and need to forward the | |
| # upsample size, we do it here | |
| if not is_final_block and forward_upsample_size: | |
| upsample_size = down_block_res_samples[-1].shape[2:] | |
| if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
| sample = upsample_block( | |
| hidden_states=sample, | |
| temb=emb, | |
| res_hidden_states_tuple=res_samples, | |
| encoder_hidden_states=encoder_hidden_states, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| upsample_size=upsample_size, | |
| attention_mask=attention_mask, | |
| ) | |
| else: | |
| sample = upsample_block( | |
| hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size | |
| ) | |
| #---------------------------- | |
| # torch.save(sample, f'{timestep.item()}_step/up_{i}.pt') | |
| #---------------------------- | |
| # 6. post-process | |
| if self.conv_norm_out: | |
| sample = self.conv_norm_out(sample) | |
| sample = self.conv_act(sample) | |
| sample = self.conv_out(sample) | |
| if not return_dict: | |
| return (sample,) | |
| return UNet2DConditionOutput(sample=sample) | |
| return forward | |
| if model.__class__.__name__ == 'UNet2DConditionModel': | |
| model.forward = normal_forward(model) | |
def refister_time(unet, t):
    setattr(unet, 'order', t)
| def register_controlnet_pipeline2(pipe): | |
| def new_call(self): | |
| # @replace_example_docstring(EXAMPLE_DOC_STRING) | |
| def call( | |
| prompt: Union[str, List[str]] = None, | |
| image: Union[ | |
| torch.FloatTensor, | |
| PIL.Image.Image, | |
| np.ndarray, | |
| List[torch.FloatTensor], | |
| List[PIL.Image.Image], | |
| List[np.ndarray], | |
| ] = None, | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_images_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| prompt_embeds: Optional[torch.FloatTensor] = None, | |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
| callback_steps: int = 1, | |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
| controlnet_conditioning_scale: Union[float, List[float]] = 1.0, | |
| guess_mode: bool = False, | |
| ): | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs( | |
| prompt, | |
| image, | |
| callback_steps, | |
| negative_prompt, | |
| prompt_embeds, | |
| negative_prompt_embeds, | |
| controlnet_conditioning_scale, | |
| ) | |
| # 2. Define call parameters | |
| if prompt is not None and isinstance(prompt, str): | |
| batch_size = 1 | |
| elif prompt is not None and isinstance(prompt, list): | |
| batch_size = len(prompt) | |
| else: | |
| batch_size = prompt_embeds.shape[0] | |
| device = self._execution_device | |
| # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | |
| # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | |
| # corresponds to doing no classifier free guidance. | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet | |
| if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): | |
| controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) | |
| global_pool_conditions = ( | |
| controlnet.config.global_pool_conditions | |
| if isinstance(controlnet, ControlNetModel) | |
| else controlnet.nets[0].config.global_pool_conditions | |
| ) | |
| guess_mode = guess_mode or global_pool_conditions | |
| # 3. Encode input prompt | |
| text_encoder_lora_scale = ( | |
| cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None | |
| ) | |
| prompt_embeds = self._encode_prompt( | |
| prompt, | |
| device, | |
| num_images_per_prompt, | |
| do_classifier_free_guidance, | |
| negative_prompt, | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=negative_prompt_embeds, | |
| lora_scale=text_encoder_lora_scale, | |
| ) | |
| # 4. Prepare image | |
| if isinstance(controlnet, ControlNetModel): | |
| image = self.prepare_image( | |
| image=image, | |
| width=width, | |
| height=height, | |
| batch_size=batch_size * num_images_per_prompt, | |
| num_images_per_prompt=num_images_per_prompt, | |
| device=device, | |
| dtype=controlnet.dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| guess_mode=guess_mode, | |
| ) | |
| height, width = image.shape[-2:] | |
| elif isinstance(controlnet, MultiControlNetModel): | |
| images = [] | |
| for image_ in image: | |
| image_ = self.prepare_image( | |
| image=image_, | |
| width=width, | |
| height=height, | |
| batch_size=batch_size * num_images_per_prompt, | |
| num_images_per_prompt=num_images_per_prompt, | |
| device=device, | |
| dtype=controlnet.dtype, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| guess_mode=guess_mode, | |
| ) | |
| images.append(image_) | |
| image = images | |
| height, width = image[0].shape[-2:] | |
| else: | |
| assert False | |
| # 5. Prepare timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = self.scheduler.timesteps | |
| # 6. Prepare latent variables | |
| num_channels_latents = self.unet.config.in_channels | |
| latents = self.prepare_latents( | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| latents, | |
| ) | |
| self.init_latent = latents.detach().clone() | |
| # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| # 8. Denoising loop | |
| #------------------------------------------------------------- | |
| all_steps = len(self.scheduler.timesteps) | |
| curr_span = 1 | |
| curr_step = 0 | |
| # st = time.time() | |
| idx = 1 | |
| keytime = [0,1,2,3,5,10,15,25,35,50] | |
| while curr_step<all_steps: | |
| # torch.cuda.empty_cache() | |
| # print(curr_step) | |
| refister_time(self.unet, curr_step) | |
| merge_span = curr_span | |
| if merge_span>0: | |
| time_ls = [] | |
| for i in range(curr_step, curr_step+merge_span): | |
| if i<all_steps: | |
| time_ls.append(self.scheduler.timesteps[i]) | |
| else: | |
| break | |
| # torch.cuda.empty_cache() | |
| ##-------------------------------- | |
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_ls[0]) | |
| if curr_step in [0,1,2,3,5,10,15,25,35]: | |
| # controlnet(s) inference | |
| control_model_input = latent_model_input | |
| controlnet_prompt_embeds = prompt_embeds | |
| down_block_res_samples, mid_block_res_sample = self.controlnet( | |
| control_model_input, | |
| time_ls[0], | |
| encoder_hidden_states=controlnet_prompt_embeds, | |
| controlnet_cond=image, | |
| conditioning_scale=controlnet_conditioning_scale, | |
| guess_mode=guess_mode, | |
| return_dict=False, | |
| ) | |
| #----------------------save controlnet feature------------------------- | |
| #useless, shoule delete | |
| # setattr(self, 'downres_samples', deepcopy(down_block_res_samples)) | |
| # setattr(self, 'midres_sample', mid_block_res_sample.detach().clone()) | |
| #-----------------------save controlnet feature------------------------ | |
| else: | |
| down_block_res_samples = None #self.downres_samples | |
| mid_block_res_sample = None #self.midres_sample | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| time_ls, | |
| encoder_hidden_states=prompt_embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| down_block_additional_residuals=down_block_res_samples, | |
| mid_block_additional_residual=mid_block_res_sample, | |
| return_dict=False, | |
| )[0] | |
| # perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| if isinstance(time_ls, list): | |
| step_span = len(time_ls) | |
| bs = noise_pred.shape[0] | |
| bs_perstep = bs//step_span | |
| denoised_latent = latents | |
| for i, timestep in enumerate(time_ls): | |
| curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep] | |
| denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0] | |
| latents = denoised_latent | |
| ##---------------------------------------- | |
| curr_step += curr_span | |
| idx += 1 | |
| if curr_step<all_steps: | |
| curr_span = keytime[idx] - keytime[idx-1] | |
| # for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")): | |
| #------------------------------------------------------------- | |
| # If we do sequential model offloading, let's offload unet and controlnet | |
| # manually for max memory savings | |
| if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | |
| self.unet.to("cpu") | |
| self.controlnet.to("cpu") | |
| torch.cuda.empty_cache() | |
| if not output_type == "latent": | |
| image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] | |
| image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) | |
| else: | |
| image = latents | |
| has_nsfw_concept = None | |
| if has_nsfw_concept is None: | |
| do_denormalize = [True] * image.shape[0] | |
| else: | |
| do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] | |
| image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | |
| # Offload last model to CPU | |
| if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | |
| self.final_offload_hook.offload() | |
| if not return_dict: | |
| return (image, has_nsfw_concept) | |
| return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | |
| return call | |
| pipe.call = new_call(pipe) | |
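# --- Usage sketch: ControlNet sampling with encoder propagation (illustrative; not from the original file) ---
# Assumes a diffusers StableDiffusionControlNetPipeline with a public canny ControlNet
# checkpoint; the input image path is a placeholder. The ControlNet itself is only run at
# the key timesteps hard-coded in the loop above.
def _example_controlnet_faster():
    from diffusers import StableDiffusionControlNetPipeline
    controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-canny', torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5', controlnet=controlnet, torch_dtype=torch.float16
    ).to('cuda')
    register_controlnet_pipeline2(pipe)   # accelerated ControlNet loop on `pipe.call`
    register_faster_forward(pipe.unet)    # reuse encoder features at non-key timesteps
    canny = get_canny('example_inputs/dog.png')   # placeholder path
    image = pipe.call(prompt='a dog in a field', image=canny,
                      num_inference_steps=50).images[0]
    return image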
def multistep_pre(self, noise_pred, t, x):
    step_span = len(t)
    bs = noise_pred.shape[0]
    bs_perstep = bs // step_span
    denoised_latent = x
    for i, timestep in enumerate(t):
        curr_noise = noise_pred[i * bs_perstep:(i + 1) * bs_perstep]
        denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent)['prev_sample']
    return denoised_latent
| def register_t2v(model): | |
| def new_back(self): | |
| def backward_loop( | |
| latents, | |
| timesteps, | |
| prompt_embeds, | |
| guidance_scale, | |
| callback, | |
| callback_steps, | |
| num_warmup_steps, | |
| extra_step_kwargs, | |
| cross_attention_kwargs=None,): | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order | |
| import time | |
| if num_steps<10: | |
| with self.progress_bar(total=num_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| setattr(self.unet, 'order', i) | |
| # expand the latents if we are doing classifier free guidance | |
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| t, | |
| encoder_hidden_states=prompt_embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ).sample | |
| # perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
| progress_bar.update() | |
| if callback is not None and i % callback_steps == 0: | |
| step_idx = i // getattr(self.scheduler, "order", 1) | |
| callback(step_idx, t, latents) | |
| else: | |
| all_timesteps = len(timesteps) | |
| curr_step = 0 | |
| while curr_step<all_timesteps: | |
| refister_time(self.unet, curr_step) | |
| time_ls = [] | |
| time_ls.append(timesteps[curr_step]) | |
| curr_step += 1 | |
| cond = curr_step in [0,1,2,3,5,10,15,25,35] | |
| while (not cond) and (curr_step<all_timesteps): | |
| time_ls.append(timesteps[curr_step]) | |
| curr_step += 1 | |
| cond = curr_step in [0,1,2,3,5,10,15,25,35] | |
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| time_ls, | |
| encoder_hidden_states=prompt_embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| ).sample | |
| # perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = multistep_pre(self, noise_pred, time_ls, latents) | |
| return latents.clone().detach() | |
| return backward_loop | |
| model.backward_loop = new_back(model) | |
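# --- Usage sketch: Text2Video-Zero with encoder propagation (illustrative; not from the original file) ---
# `register_t2v` replaces the `backward_loop` that diffusers' TextToVideoZeroPipeline calls
# internally with the multi-timestep variant above; `register_faster_forward` on the UNet is
# also needed so the merged timestep list is handled. Checkpoint id and prompt are examples,
# and exact behaviour depends on the installed diffusers version.
def _example_text2video_zero():
    from diffusers import TextToVideoZeroPipeline
    pipe = TextToVideoZeroPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16
    ).to('cuda')
    register_faster_forward(pipe.unet)
    register_t2v(pipe)   # patch the pipeline's `backward_loop`
    frames = pipe(prompt='a panda is playing guitar on times square').images
    return frames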