from typing import Literal, Union, Optional, Tuple, List

import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoencoderKL,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
)
from safetensors.torch import load_file
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)

from omegaconf import OmegaConf

NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024

TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "unipc"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

# None means the default Hugging Face cache directory is used.
DIFFUSERS_CACHE_DIR = None


def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
    """Load an LDM-format checkpoint and rename legacy text-encoder keys
    (saved without the ``text_model.`` prefix) to the layout current
    ``transformers`` CLIP models expect.

    Returns a ``(checkpoint, state_dict)`` tuple. ``checkpoint`` is None for
    .safetensors files, which only store a flat state dict.
    """
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        (
            "cond_stage_model.transformer.embeddings.",
            "cond_stage_model.transformer.text_model.embeddings.",
        ),
        (
            "cond_stage_model.transformer.encoder.",
            "cond_stage_model.transformer.text_model.encoder.",
        ),
        (
            "cond_stage_model.transformer.final_layer_norm.",
            "cond_stage_model.transformer.text_model.final_layer_norm.",
        ),
    ]

    if ckpt_path.endswith(".safetensors"):
        checkpoint = None
        state_dict = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
            checkpoint = None

    # Collect the renames first, then apply them, so the dict is not mutated
    # while it is being iterated.
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from) :]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint, state_dict


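# Example use of load_checkpoint_with_text_encoder_conversion (the path below is
# hypothetical; any SD v1.x LDM checkpoint would do):
#
#   _, state_dict = load_checkpoint_with_text_encoder_conversion("v1-5-pruned.safetensors")
#   assert not any(
#       k.startswith("cond_stage_model.transformer.embeddings.") for k in state_dict
#   )  # legacy keys have been renamed under text_model.

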
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
    """
    Create a diffusers UNet2DConditionModel config from the hard-coded LDM
    hyperparameters above (SD v1 defaults, with the v2 overrides applied when
    requested).
    """
    block_out_channels = [
        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
    ]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnDownBlock2D"
            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
            else "DownBlock2D"
        )
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnUpBlock2D"
            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
            else "UpBlock2D"
        )
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
        if not v2
        else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS
        if not v2
        else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )
    if v2 and use_linear_projection_in_v2:
        config["use_linear_projection"] = True

    return config


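# For reference, with the constants above, create_unet_diffusers_config(v2=False)
# resolves to the standard SD v1 layout:
#   block_out_channels  = (320, 640, 1280, 1280)
#   down_block_types    = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
#   up_block_types      = ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3
#   cross_attention_dim = 768, attention_head_dim = 8

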
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # SD v2 uses the penultimate hidden state by default, so keep 23 of
            # the 24 layers when no explicit clip_skip is given.
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    return tokenizer, text_encoder, unet, vae


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=v2,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    # Rebuild the UNet directly from the raw LDM state dict so its config
    # (e.g. linear projection for v2) matches create_unet_diffusers_config.
    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
    unet_config["class_embed_type"] = None
    unet_config["addition_embed_type"] = None
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet, vae


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    CLIPTokenizer,
    CLIPTextModel,
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:
        tokenizer, text_encoder, unet, vae = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    if scheduler_name:
        scheduler = create_noise_scheduler(
            scheduler_name,
            prediction_type="v_prediction" if v_pred else "epsilon",
        )
    else:
        scheduler = None

    return tokenizer, text_encoder, unet, scheduler, vae


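# Minimal usage sketch (the model id and scheduler choice are just examples):
#
#   tokenizer, text_encoder, unet, scheduler, vae = load_models(
#       "runwayml/stable-diffusion-v1-5", scheduler_name="ddim"
#   )
#   unet.to(get_torch_device())

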
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    # SDXL uses two tokenizer/text-encoder pairs (CLIP ViT-L and OpenCLIP ViT-bigG).
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    return tokenizers, text_encoders, unet, vae


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    # Match the diffusers loader above: the second (OpenCLIP) text encoder pads with 0.
    text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet, vae


def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizers, text_encoders, unet, vae = load_checkpoint_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    else:
        tokenizers, text_encoders, unet, vae = load_diffusers_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )

    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
    else:
        scheduler = None

    return tokenizers, text_encoders, unet, scheduler, vae


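# SDXL usage sketch (the model id is just an example; any diffusers-format SDXL
# repo or a single .safetensors checkpoint path works):
#
#   tokenizers, text_encoders, unet, scheduler, vae = load_models_xl(
#       "stabilityai/stable-diffusion-xl-base-1.0", scheduler_name="ddim"
#   )

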
def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    noise_scheduler_kwargs=None,
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    name = scheduler_name.lower().replace(" ", "_")

    # noise_scheduler_kwargs may be an OmegaConf node, a plain dict, or None.
    # Fall back to the standard SD schedule defined at the top of this module,
    # and forward prediction_type unless the kwargs already specify it.
    if noise_scheduler_kwargs is None:
        kwargs = dict(
            num_train_timesteps=NUM_TRAIN_TIMESTEPS,
            beta_start=BETA_START,
            beta_end=BETA_END,
            beta_schedule="scaled_linear",
        )
    elif OmegaConf.is_config(noise_scheduler_kwargs):
        kwargs = OmegaConf.to_container(noise_scheduler_kwargs)
    else:
        kwargs = dict(noise_scheduler_kwargs)
    kwargs.setdefault("prediction_type", prediction_type)

    if name == "ddim":
        scheduler = DDIMScheduler(**kwargs)
    elif name == "ddpm":
        scheduler = DDPMScheduler(**kwargs)
    elif name == "lms":
        scheduler = LMSDiscreteScheduler(**kwargs)
    elif name == "euler_a":
        scheduler = EulerAncestralDiscreteScheduler(**kwargs)
    elif name == "euler":
        scheduler = EulerDiscreteScheduler(**kwargs)
    elif name == "unipc":
        scheduler = UniPCMultistepScheduler(**kwargs)
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler


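# Scheduler construction sketch (the kwargs shown are illustrative; anything the
# chosen diffusers scheduler accepts can be passed through):
#
#   from omegaconf import OmegaConf
#
#   kwargs = OmegaConf.create(
#       {"num_train_timesteps": 1000, "beta_schedule": "scaled_linear"}
#   )
#   scheduler = create_noise_scheduler("euler_a", kwargs, prediction_type="epsilon")

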
def torch_gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


from enum import Enum


class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2


cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False


def is_intel_xpu():
    global cpu_state
    global xpu_available
    if cpu_state == CPUState.GPU:
        if xpu_available:
            return True
    return False


# Importing intel_extension_for_pytorch registers the torch.xpu backend.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401

    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    pass


def get_torch_device():
    global directml_enabled
    global cpu_state
    if directml_enabled:
        # directml_device is only defined by an (optional) DirectML setup path.
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    else:
        if is_intel_xpu():
            return torch.device("xpu")
        else:
            return torch.device(torch.cuda.current_device())
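

# Typical use at the top of a training or inference loop (sketch):
#
#   device = get_torch_device()
#   unet.to(device)
#   ...
#   torch_gc()  # release cached GPU memory between runs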