from dataclasses import dataclass
from typing import Union, List, Optional

import PIL.Image
import numpy as np
from tqdm.auto import trange

# The star import provides (among others): torch, retrieve_timesteps, randn_tensor,
# rescale_noise_cfg, StableDiffusionXLImg2ImgPipeline, CLIPVisionModelWithProjection,
# CLIPImageProcessor — all referenced below without explicit imports.
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerDiscreteScheduler
from diffusers.utils.outputs import BaseOutput

from modules.layerdiffuse.vae import TransparentVAEDecoder, TransparentVAEEncoder, vae_encode
from .layerdiff3d import UNetFrameConditionModel
from utils.torch_utils import seed_everything, img2tensor, tensor2img


@dataclass
class LayerdiffPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or numpy array of shape
            `(batch_size, height, width, num_channels)`. PIL images or numpy array
            present the denoised images of the diffusion pipeline.
        vis_list (`List[PIL.Image.Image]` or `np.ndarray`):
            Per-sample visualization outputs returned by the transparent VAE decoder
            alongside the final results.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    vis_list: Union[List[PIL.Image.Image], np.ndarray]


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, show_progress=True, c_concat=None):
    """DPM-Solver++(2M) sampler (k-diffusion style).

    Args:
        model: denoiser called as ``model(x, sigma * s_in, c_concat=..., **extra_args)``.
        x: initial noisy latent tensor.
        sigmas: decreasing noise-level schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional per-step callback receiving a dict of loop state.
        show_progress: show a tqdm progress bar when True.
        c_concat: extra conditioning tensor forwarded to ``model``.

    Returns:
        The latent after the final solver step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Change of variables: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        model_input = x
        denoised = model(model_input, sigmas[i] * s_in, c_concat=c_concat, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First step (no history yet) or final step to sigma=0: first-order update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep update combining the previous denoised estimate.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x


class KDiffusionStableDiffusionXLPipeline(StableDiffusionXLImg2ImgPipeline):
    """SDXL img2img pipeline variant with layer-diffusion conditioning.

    Adds a transparent VAE (``trans_vae``) for encoding/decoding layered images and
    supports an optional frame-conditioned ("3d") UNet (``UNetFrameConditionModel``).
    """

    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "image_encoder",
        "feature_extractor",
    ]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        text_encoder_2,
        tokenizer_2,
        unet,
        scheduler=None,
        trans_vae=None,
        tag_list=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        requires_aesthetics_score: bool = False,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        if scheduler is None:
            # Build a default scheduler when none is supplied. The table maps sampler
            # names to (scheduler class, config) under three "final sigma" conventions;
            # only one entry is selected below, the rest document the available options.
            config_min = {"final_sigmas_type": "sigma_min"}
            config_min_euler = {"final_sigmas_type": "sigma_min", "euler_at_final": True}
            config_zero = {"final_sigmas_type": "zero"}
            schedulers = {
                "DPMPP_2M": {
                    "min": (DPMSolverMultistepScheduler, config_min),
                    "min_euler": (DPMSolverMultistepScheduler, config_min_euler),
                    "zero": (DPMSolverMultistepScheduler, config_zero),
                },
                "DPMPP_2M_K": {
                    "min": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, **config_zero}),
                },
                "DPMPP_2M_SDE": {
                    "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", **config_zero}),
                },
                "DPMPP_2M_SDE_K": {
                    "min": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++", **config_zero}),
                },
                "DPMPP": {
                    "min": (DPMSolverSinglestepScheduler, config_min),
                    "min_euler": (DPMSolverSinglestepScheduler, config_min_euler),
                    "zero": (DPMSolverSinglestepScheduler, config_zero),
                },
                "DPMPP_K": {
                    "min": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min}),
                    "min_euler": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_min_euler}),
                    "zero": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, **config_zero}),
                },
            }
            # NOTE(review): the default scheduler config is fetched from a hard-coded
            # hub repo at construction time (network dependency); pass `scheduler=`
            # explicitly to avoid it.
            model_id = "frankjoshua/juggernautXL_version6Rundiffusion"
            scheduler_name = "DPMPP_2M_SDE"
            scheduler_config_name = "zero"
            scheduler_cls, scheduler_config = schedulers[scheduler_name][scheduler_config_name]
            scheduler = scheduler_cls.from_pretrained(
                model_id,
                subfolder="scheduler",
                **scheduler_config,
            )
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            requires_aesthetics_score=requires_aesthetics_score,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )
        # self.register_to_config(tag_list=tag_list)
        self.register_modules(trans_vae=trans_vae)

    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for guidance_scale > 1 and when the UNet has no
        # guidance-embedding input (time_cond_proj_dim, e.g. LCM-distilled models).
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @torch.inference_mode()
    def encode_cropped_prompt_77tokens(self, prompt: str):
        """Encode `prompt` with both SDXL text encoders, truncated to one 77-token window.

        Returns:
            (prompt_embeds, pooled_prompt_embeds) — the concatenated penultimate
            hidden states of both encoders, and the pooled output of the final
            (second) text encoder.
        """
        device = self.text_encoder.device
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        pooled_prompt_embeds = None
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_input_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True, return_dict=False)
            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            # Penultimate hidden state, the conventional SDXL conditioning layer.
            prompt_embeds = prompt_embeds[-1][-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(dtype=self.unet.dtype, device=device)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds

    def denoise_func(self, latents, add_text_embeds, add_time_ids, prompt_embeds, c_concat, num_inference_steps=50):
        """Run the scheduler's denoising loop over `latents` with fixed conditioning.

        `c_concat` is concatenated channel-wise (dim=-3) to the latent input of the
        UNet at every step.
        """
        # 4. Prepare timesteps
        device = self.unet.device
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps=None, sigmas=None
        )
        latents = latents * self.scheduler.init_noise_sigma
        for i, t in enumerate(timesteps):
            # NOTE(review): when do_classifier_free_guidance is True the latent batch
            # is doubled but prompt_embeds/add_text_embeds/c_concat are not — the
            # batch sizes would not match. Presumably callers only use this with CFG
            # disabled (guidance-embedded UNet); TODO confirm.
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                torch.cat([latent_model_input, c_concat], dim=-3),
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                if self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug:
                    # https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)
        return latents

    @torch.inference_mode()
    def __call__(
        self,
        initial_latent: torch.FloatTensor = None,
        strength: float = 1.0,
        num_inference_steps: int = 25,
        guidance_scale: float = 5.0,
        batch_size: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        c_concat=None,
        prompt=None,
        negative_prompt=None,
        show_progress=True,
        fullpage=None,
        group_index=None,
    ):
        """Generate layered images conditioned on `c_concat` (or an encoded `fullpage`).

        Either `c_concat` (a pre-encoded latent) or `fullpage` (an HxWx4 RGBA array
        whose alpha is used as a decode mask — assumed shape, TODO confirm against
        callers) must be provided. Returns a `LayerdiffPipelineOutput` when
        `trans_vae` is set, otherwise the raw latents.
        """
        device = self.unet.device
        dtype = self.unet.dtype

        # Fix: `page_alpha` must exist even without `fullpage`, because it is always
        # passed as `mask=` to the transparent VAE decoder at the end.
        page_alpha = None
        if fullpage is not None:
            page_alpha = img2tensor(fullpage[..., -1] / 255., device=self.vae.device, dtype=self.vae.dtype)[0][..., None]
            fullpage = fullpage[..., :3]
            # Prepend an opaque alpha channel, then encode through the transparent VAE.
            c_concat = np.concatenate([np.full_like(fullpage[..., :1], fill_value=255), fullpage], axis=2)
            c_concat = img2tensor(c_concat, normalize=True)
            c_concat = vae_encode(self.vae, self.trans_vae.encoder, c_concat, use_offset=False).to(device=device, dtype=dtype)
        # Fix: validate before dereferencing — the original called `.to()` first,
        # which would raise AttributeError on None before the assert could fire.
        assert c_concat is not None
        c_concat = c_concat.to(dtype=dtype)

        self._guidance_scale = guidance_scale
        is_3d = isinstance(self.unet, UNetFrameConditionModel)
        lh, lw = c_concat.shape[-2:]

        # For the frame-conditioned UNet, one frame per prompt (or prompt embedding).
        num_frames = 1
        if is_3d:
            if prompt is not None:
                num_frames = len(prompt)
            if prompt_embeds is not None:
                num_frames = len(prompt_embeds)

        if initial_latent is None:
            initial_latent = torch.zeros((batch_size, 4, lh, lw), device=self.unet.device, dtype=self.unet.dtype)
        if is_3d and c_concat.ndim == 4:
            c_concat = c_concat[:, None].expand(-1, num_frames, -1, -1, -1)
        if is_3d and initial_latent.ndim == 4:
            initial_latent = initial_latent[:, None].expand(-1, num_frames, -1, -1, -1)

        if prompt is not None:
            prompt_embeds, pooled_prompt_embeds = self.encode_cropped_prompt_77tokens(prompt)
        if negative_prompt is not None and self.do_classifier_free_guidance:
            negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_cropped_prompt_77tokens(negative_prompt)

        # Initial noise. In the 5-D (frame-conditioned) case the same noise is shared
        # across all frames. Fix: the original unconditionally used a 5-argument
        # expand, which is invalid for the 4-D (non-3d) path.
        if initial_latent.ndim == 5:
            noise = randn_tensor(
                initial_latent[:, [0]].shape, generator=generator, device=device, dtype=self.unet.dtype
            ).expand(-1, num_frames, -1, -1, -1)
        else:
            noise = randn_tensor(initial_latent.shape, generator=generator, device=device, dtype=self.unet.dtype)

        # SDXL micro-conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
        height = lh * self.vae_scale_factor
        width = lw * self.vae_scale_factor
        add_time_ids = list((height, width) + (0, 0) + (height, width))
        add_time_ids = torch.tensor([add_time_ids], dtype=self.unet.dtype)
        add_time_ids = add_time_ids.expand((prompt_embeds.shape[0], -1))
        add_neg_time_ids = add_time_ids.clone()

        # Batch
        add_time_ids = add_time_ids.repeat(batch_size, 1).to(device)
        add_neg_time_ids = add_neg_time_ids.repeat(batch_size, 1).to(device)
        prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1).to(device)
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(device)

        # NOTE(review): sampler_kwargs is assembled but never consumed in this method
        # (presumably intended for sample_dpmpp_2m); kept for compatibility.
        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            ),
        )
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1).to(device)
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(device)
            sampler_kwargs['negative'] = dict(
                encoder_hidden_states=negative_prompt_embeds,
                added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
            )

        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps=None, sigmas=None
        )
        latents = noise * self.scheduler.init_noise_sigma
        for i, t in enumerate(timesteps):
            # NOTE(review): as in denoise_func, the CFG branch doubles only the latent
            # batch, not the conditioning/c_concat — it appears this pipeline is used
            # with CFG disabled (guidance-embedded UNet). TODO confirm.
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                torch.cat([latent_model_input, c_concat], dim=-3),
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
                group_index=group_index,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                if self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug:
                    # https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

        # Drop the batch dimension of the frame-conditioned output (frames remain).
        if latents.ndim == 5:
            latents = latents[0]
        if self.trans_vae is None:
            return latents

        # Undo the VAE scaling before decoding each latent through the transparent VAE.
        latents = latents.to(dtype=self.trans_vae.dtype, device=self.trans_vae.device) / self.vae.config.scaling_factor
        vis_list = []
        res_list = []
        for latent in latents:
            latent = latent[None]
            result_list, vis_list_batch = self.trans_vae.decoder(self.vae, latent, mask=page_alpha)
            vis_list += vis_list_batch
            res_list += result_list
        return LayerdiffPipelineOutput(images=res_list, vis_list=vis_list)