Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import gc | |
| import json | |
| import random | |
| import shutil | |
| import typing | |
| from typing import Union, List, Literal, Iterator | |
| import sys | |
| import os | |
| from collections import OrderedDict | |
| import copy | |
| import yaml | |
| from PIL import Image | |
| from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \ | |
| ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN | |
| from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg | |
| from safetensors.torch import save_file, load_file | |
| from torch import autocast | |
| from torch.nn import Parameter | |
| from torch.utils.checkpoint import checkpoint | |
| from tqdm import tqdm | |
| from torchvision.transforms import Resize, transforms | |
| from toolkit.assistant_lora import load_assistant_lora_from_path | |
| from toolkit.clip_vision_adapter import ClipVisionAdapter | |
| from toolkit.custom_adapter import CustomAdapter | |
| from toolkit.ip_adapter import IPAdapter | |
| from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ | |
| convert_vae_state_dict, load_vae | |
| from toolkit import train_tools | |
| from toolkit.config_modules import ModelConfig, GenerateImageConfig | |
| from toolkit.metadata import get_meta_for_safetensors | |
| from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT | |
| from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds | |
| from toolkit.reference_adapter import ReferenceAdapter | |
| from toolkit.sampler import get_sampler | |
| from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler | |
| from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers | |
| from toolkit.sd_device_states_presets import empty_preset | |
| from toolkit.train_tools import get_torch_dtype, apply_noise_offset | |
| from einops import rearrange, repeat | |
| import torch | |
| from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ | |
| StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline | |
| from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ | |
| StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ | |
| StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ | |
| StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ | |
| StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ | |
| FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel | |
| import diffusers | |
| from diffusers import \ | |
| AutoencoderKL, \ | |
| UNet2DConditionModel | |
| from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline | |
| from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast | |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection | |
| from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT | |
| from huggingface_hub import hf_hub_download | |
| from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from toolkit.lora_special import LoRASpecialNetwork | |
# Keep diffusers quiet: only messages at ERROR level or above are logged.
diffusers.logging.set_verbosity(
    diffusers.logging.ERROR
)
# Short prefixes identifying each model component (the keys in DO_NOT_TRAIN_WEIGHTS below are
# diffusers keys carrying such a prefix).
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te0"
SD_PREFIX_TEXT_ENCODER2 = "te1"

# Prefixed diffusers keys excluded from training: bias and weight of both time-embedding linear
# layers, for the UNet and for the refiner UNet (order: unet first, linear_1 before linear_2,
# bias before weight).
DO_NOT_TRAIN_WEIGHTS = [
    f"{component}_time_embedding.linear_{layer}.{param}"
    for component in (SD_PREFIX_UNET, SD_PREFIX_REFINER_UNET)
    for layer in (1, 2)
    for param in ("bias", "weight")
]

# Literal type naming the supported device-state presets.
DeviceStatePreset = Literal['cache_latents', 'generate']
class BlankNetwork:
    """Placeholder network object that holds no weights.

    It exposes the state attributes ``multiplier``, ``is_active``, ``is_merged_in`` and
    ``can_merge_in``, presumably mirroring what callers expect from a real (e.g. LoRA)
    network. It can also be used as a context manager that sets ``is_active`` to True
    on entry and to False on exit.
    """

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_merged_in = False
        self.can_merge_in = False

    def __enter__(self):
        """Mark the network active and return it, so ``with net as n`` binds ``n`` to ``net``."""
        self.is_active = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Mark the network inactive. Returns None, so exceptions from the body propagate."""
        self.is_active = False
def flush():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference cycles are
    freed before ``torch.cuda.empty_cache()`` returns the allocator's unused cached
    blocks. In the reverse order, memory freed by the collection would stay cached until
    the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
# Stable Diffusion UNet in_channels is fixed at 4; SDXL uses the same value.
# (The VAE scale factor is not a module constant: load_model derives it from the VAE config as
#  2 ** (len(vae.config['block_out_channels']) - 1).)
UNET_IN_CHANNELS = 4
| class StableDiffusion: | |
def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='fp16',
        custom_pipeline=None,
        noise_scheduler=None,
        quantize_device=None,
):
    """Record device/dtype settings and architecture flags; no weights are loaded here.

    Args:
        device: main device for model components (wrapped with ``torch.device``).
        model_config: source of the architecture flags, the per-component device/dtype
            overrides and the loading options.
        dtype: dtype name for the main model, resolved with ``get_torch_dtype``.
        custom_pipeline: optional pipeline class used instead of the architecture default.
        noise_scheduler: optional scheduler. Passing a ``CustomFlowMatchEulerDiscreteScheduler``
            marks the model as flow matching.
        quantize_device: device used for quantization; falls back to ``device``.

    Call ``load_model`` to actually load the components.
    """
    cfg = model_config

    self.custom_pipeline = custom_pipeline
    self.device = device
    self.dtype = dtype
    self.torch_dtype = get_torch_dtype(dtype)
    self.device_torch = torch.device(self.device)

    # The VAE and text encoder may override the main device; otherwise they follow it.
    if cfg.vae_device is None:
        self.vae_device_torch = torch.device(self.device)
    else:
        self.vae_device_torch = torch.device(cfg.vae_device)
    self.vae_torch_dtype = get_torch_dtype(cfg.vae_dtype)
    if cfg.te_device is None:
        self.te_device_torch = torch.device(self.device)
    else:
        self.te_device_torch = torch.device(cfg.te_device)
    self.te_torch_dtype = get_torch_dtype(cfg.te_dtype)

    self.model_config = cfg
    if cfg.is_v_pred:
        self.prediction_type = "v_prediction"
    else:
        self.prediction_type = "epsilon"
    self.device_state = None

    # Components assigned by load_model (annotation-only here: these attributes are not set yet).
    self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
    self.vae: Union[None, 'AutoencoderKL']
    self.unet: Union[None, 'UNet2DConditionModel']
    self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
    self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
    self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
    self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
    self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None

    # sdxl stuff
    self.logit_scale = None
    self.ckppt_info = None
    self.is_loaded = False

    # network attached for training, if there is one
    self.network = None
    self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None

    # architecture flags copied from the config
    self.is_xl = cfg.is_xl
    self.is_v2 = cfg.is_v2
    self.is_ssd = cfg.is_ssd
    self.is_v3 = cfg.is_v3
    self.is_vega = cfg.is_vega
    self.is_pixart = cfg.is_pixart
    self.is_auraflow = cfg.is_auraflow
    self.is_flux = cfg.is_flux
    self.use_text_encoder_1 = cfg.use_text_encoder_1
    self.use_text_encoder_2 = cfg.use_text_encoder_2
    self.config_file = None

    # Flux, SD3 (is_v3) and AuraFlow are flow-matching models; a custom flow-match scheduler also
    # enables it.
    self.is_flow_matching = bool(
        self.is_flux
        or self.is_v3
        or self.is_auraflow
        or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler)
    )

    if quantize_device is None:
        self.quantize_device = self.device
    else:
        self.quantize_device = quantize_device
    self.low_vram = cfg.low_vram

    # when True, the assistant LoRA is merged in and previewed with -1 weight
    self.invert_assistant_lora = False
def load_model(self):
    """Load all components for the configured architecture and attach them to ``self``.

    Returns immediately when ``self.is_loaded`` is already True. Otherwise the
    architecture flags on ``self.model_config`` select one loading path:
    - XL/SSD/Vega;
    - SD3 (``is_v3``);
    - PixArt;
    - AuraFlow;
    - Flux;
    - otherwise ``StableDiffusionPipeline`` or ``custom_pipeline``.

    Each path builds a diffusers pipeline, moves components to their configured
    device/dtype and freezes them with ``requires_grad_(False)`` + ``eval()``.

    On completion it sets:
    - ``self.tokenizer`` and ``self.text_encoder`` (lists for multi-encoder models);
    - ``self.unet`` (the transformer for PixArt/SD3/AuraFlow/Flux);
    - ``self.vae`` and ``self.vae_scale_factor``;
    - ``self.pipeline`` and ``self.is_loaded``.

    It then calls ``self.load_refiner()`` and loads any assistant/inference LoRA
    configured in ``model_config``.

    Raises:
        ValueError: if a LoRA is requested for SD3; if both an assistant and an
            inference LoRA are configured; if either is combined with ``lora_path``;
            or if a low-VRAM Flux LoRA has keys outside the single/double
            transformer blocks.
    """
    if self.is_loaded:
        return
    dtype = get_torch_dtype(self.dtype)
    # move the betas alphas and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
    # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
    # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
    # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
    model_path = self.model_config.name_or_path
    if 'civitai.com' in self.model_config.name_or_path:
        # it is a civitai model, use the civitai loader to resolve a local path
        from toolkit.civitai import get_model_path_from_url
        model_path = get_model_path_from_url(self.model_config.name_or_path)
    # optional overrides forwarded as **load_args to several of the pipeline loaders below
    load_args = {}
    if self.noise_scheduler:
        load_args['scheduler'] = self.noise_scheduler
    if self.model_config.vae_path is not None:
        load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
    if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
        # ---- SDXL family (XL / SSD / Vega): two text encoders ----
        if self.custom_pipeline is not None:
            pipln = self.custom_pipeline
        else:
            pipln = StableDiffusionXLPipeline
            # pipln = StableDiffusionKDiffusionXLPipeline
        # Two loading cases:
        # - a path that does not exist locally (presumably a hub repo id) or a directory
        #   goes through from_pretrained;
        # - an existing file is loaded as a single-file checkpoint.
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            # try to load with default diffusers
            # NOTE(review): diffusers' from_pretrained takes `torch_dtype`. The `dtype=` and `device=`
            # kwargs here (and in the similar calls below) look like they are ignored; components
            # are cast afterwards — confirm.
            pipe = pipln.from_pretrained(
                model_path,
                dtype=dtype,
                device=self.device_torch,
                # variant="fp16",
                use_safetensors=True,
                **load_args
            )
        else:
            pipe = pipln.from_single_file(
                model_path,
                device=self.device_torch,
                torch_dtype=self.torch_dtype,
            )
            # load_args is not forwarded to from_single_file, so attach a custom VAE by hand
            if 'vae' in load_args and load_args['vae'] is not None:
                pipe.vae = load_args['vae']
        flush()
        text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
        # freeze both text encoders on the configured TE device/dtype
        for text_encoder in text_encoders:
            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
        text_encoder = text_encoders
        pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
        if self.model_config.experimental_xl:
            print("Experimental XL mode enabled")
            print("Loading and injecting alt weights")
            # load the mismatched weight and force it in
            # (load_file reads a safetensors file, so this presumably requires a single-file checkpoint)
            raw_state_dict = load_file(model_path)
            replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone()
            del raw_state_dict
            # get the state dict for the 2nd text encoder
            te1_state_dict = text_encoders[1].state_dict()
            # replace weight with mismatched weight
            # NOTE(review): state_dict() returns a new mapping, and it is never passed back through
            # load_state_dict(). This assignment therefore does not appear to change the encoder's
            # weights — confirm the intent.
            te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
            flush()
            print("Injecting alt weights")
    elif self.model_config.is_v3:
        # ---- SD3 (is_v3): transformer + three text encoders (the third is T5) ----
        if self.custom_pipeline is not None:
            pipln = self.custom_pipeline
        else:
            pipln = StableDiffusion3Pipeline
        print("Loading SD3 model")
        # assume it is the large model
        base_model_path = "stabilityai/stable-diffusion-3.5-large"
        print("Loading transformer")
        subfolder = 'transformer'
        transformer_path = model_path
        # a local path means the transformer lives in its 'transformer' subdirectory
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path
        else:
            # it is remote: use whatever path we were given
            base_model_path = model_path
        transformer = SD3Transformer2DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
        )
        if not self.low_vram:
            # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
            transformer.to(torch.device(self.quantize_device), dtype=dtype)
        flush()
        if self.model_config.lora_path is not None:
            raise ValueError("LoRA is not supported for SD3 models currently")
        if self.model_config.quantize:
            quantization_type = qfloat8
            print("Quantizing transformer")
            quantize(transformer, weights=quantization_type)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)
        # NOTE(review): `scheduler` and `vae` are loaded here but not passed to either pipeline
        # loader below — confirm whether they are meant to be used.
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
        print("Loading vae")
        vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()
        print("Loading t5")
        tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype)
        text_encoder_3 = T5EncoderModel.from_pretrained(
            base_model_path,
            subfolder="text_encoder_3",
            torch_dtype=dtype
        )
        text_encoder_3.to(self.device_torch, dtype=dtype)
        flush()
        if self.model_config.quantize:
            print("Quantizing T5")
            quantize(text_encoder_3, weights=qfloat8)
            freeze(text_encoder_3)
            flush()
        # see if path exists
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            try:
                # try to load with default diffusers
                pipe = pipln.from_pretrained(
                    base_model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    tokenizer_3=tokenizer_3,
                    text_encoder_3=text_encoder_3,
                    transformer=transformer,
                    # variant="fp16",
                    use_safetensors=True,
                    repo_type="model",
                    # NOTE(review): "*..gitattributes" (two dots) does not match a plain ".gitattributes";
                    # probably meant "*.gitattributes" — confirm.
                    ignore_patterns=["*.md", "*..gitattributes"],
                    **load_args
                )
            except Exception as e:
                print(f"Error loading from pretrained: {e}")
                raise e
        else:
            pipe = pipln.from_single_file(
                model_path,
                transformer=transformer,
                device=self.device_torch,
                torch_dtype=self.torch_dtype,
                tokenizer_3=tokenizer_3,
                text_encoder_3=text_encoder_3,
                **load_args
            )
        flush()
        text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
        # replace the to function with a no-op since it throws an error instead of a warning
        # text_encoders[2].to = lambda *args, **kwargs: None
        for text_encoder in text_encoders:
            text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
        text_encoder = text_encoders
    elif self.model_config.is_pixart:
        # ---- PixArt (alpha or sigma): transformer + T5 text encoder ----
        te_kwargs = {}
        # handle quantization of the TE (8-/4-bit, requested via text_encoder_bits)
        te_is_quantized = False
        if self.model_config.text_encoder_bits == 8:
            te_kwargs['load_in_8bit'] = True
            te_kwargs['device_map'] = "auto"
            te_is_quantized = True
        elif self.model_config.text_encoder_bits == 4:
            te_kwargs['load_in_4bit'] = True
            te_kwargs['device_map'] = "auto"
            te_is_quantized = True
        main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
        if self.model_config.is_pixart_sigma:
            main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
        # NOTE(review): this unconditionally overrides both defaults above, which makes them dead code.
        main_model_path = model_path
        # load the TE (quantized when te_kwargs requests it)
        text_encoder = T5EncoderModel.from_pretrained(
            main_model_path,
            subfolder="text_encoder",
            torch_dtype=self.torch_dtype,
            **te_kwargs
        )
        # load the transformer
        subfolder = "transformer"
        # check if it is just the unet
        if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
            subfolder = None
        if te_is_quantized:
            # replace the to function with a no-op since it throws an error instead of a warning
            text_encoder.to = lambda *args, **kwargs: None
        text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
        if self.model_config.is_pixart_sigma:
            # load the transformer only from the save
            transformer = Transformer2DModel.from_pretrained(
                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                torch_dtype=self.torch_dtype,
                subfolder='transformer'
            )
            pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                dtype=dtype,
                device=self.device_torch,
                **load_args
            )
        else:
            # load the transformer only from the save
            transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype,
                                                             subfolder=subfolder)
            pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                dtype=dtype,
                device=self.device_torch,
                **load_args
            ).to(self.device_torch)
        if self.model_config.unet_sample_size is not None:
            pipe.transformer.config.sample_size = self.model_config.unet_sample_size
        pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
        flush()
        # text_encoder = pipe.text_encoder
        # text_encoder.to(self.device_torch, dtype=dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
        tokenizer = pipe.tokenizer
        pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
        # PixArt is the only path that backfills self.noise_scheduler from the pipeline
        if self.noise_scheduler is None:
            self.noise_scheduler = pipe.scheduler
    elif self.model_config.is_auraflow:
        # ---- AuraFlow: transformer + UMT5 text encoder ----
        te_kwargs = {}
        # handle quantization of the TE (8-/4-bit, requested via text_encoder_bits)
        te_is_quantized = False
        if self.model_config.text_encoder_bits == 8:
            te_kwargs['load_in_8bit'] = True
            te_kwargs['device_map'] = "auto"
            te_is_quantized = True
        elif self.model_config.text_encoder_bits == 4:
            te_kwargs['load_in_4bit'] = True
            te_kwargs['device_map'] = "auto"
            te_is_quantized = True
        main_model_path = model_path
        # load the TE (quantized when te_kwargs requests it)
        text_encoder = UMT5EncoderModel.from_pretrained(
            main_model_path,
            subfolder="text_encoder",
            torch_dtype=self.torch_dtype,
            **te_kwargs
        )
        # load the transformer
        subfolder = "transformer"
        # check if it is just the unet
        # NOTE(review): `subfolder` is not read below; the transformer is always loaded from 'transformer'.
        if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
            subfolder = None
        if te_is_quantized:
            # replace the to function with a no-op since it throws an error instead of a warning
            text_encoder.to = lambda *args, **kwargs: None
        # load the transformer only from the save
        transformer = AuraFlowTransformer2DModel.from_pretrained(
            model_path if self.model_config.unet_path is None else self.model_config.unet_path,
            torch_dtype=self.torch_dtype,
            subfolder='transformer'
        )
        pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
            main_model_path,
            transformer=transformer,
            text_encoder=text_encoder,
            dtype=dtype,
            device=self.device_torch,
            **load_args
        )
        pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
        # patch auraflow so it can handle other aspect ratios
        # patch_auraflow_pos_embed(pipe.transformer.pos_embed)
        flush()
        # text_encoder = pipe.text_encoder
        # text_encoder.to(self.device_torch, dtype=dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
        tokenizer = pipe.tokenizer
    elif self.model_config.is_flux:
        # ---- Flux: transformer + CLIP text encoder + T5 text encoder ----
        print("Loading Flux model")
        base_model_path = "black-forest-labs/FLUX.1-schnell"
        print("Loading transformer")
        subfolder = 'transformer'
        transformer_path = model_path
        # NOTE(review): local_files_only is assigned but never read.
        local_files_only = False
        # a local path means the transformer lives in its 'transformer' subdirectory
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path
        transformer = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
            # low_cpu_mem_usage=False,
            # device_map=None
        )
        if not self.low_vram:
            # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
            transformer.to(torch.device(self.quantize_device), dtype=dtype)
        flush()
        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
            if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None:
                raise ValueError("Cannot load both assistant lora and inference lora at the same time")
            if self.model_config.lora_path:
                raise ValueError("Cannot load both assistant lora and lora at the same time")
            # Should be unreachable here: this branch requires model_config.is_flux, and __init__
            # copies that flag to self.is_flux.
            if not self.is_flux:
                raise ValueError("Assistant/ inference lora is only supported for flux models currently")
            load_lora_path = self.model_config.inference_lora_path
            if load_lora_path is None:
                load_lora_path = self.model_config.assistant_lora_path
            if os.path.isdir(load_lora_path):
                # NOTE(review): the joined file path is stored only in the local load_lora_path, which
                # is not read again. model_config keeps the directory, which load_file() in the
                # low-VRAM fuse path below cannot open — confirm.
                load_lora_path = os.path.join(
                    load_lora_path, "pytorch_lora_weights.safetensors"
                )
            elif not os.path.exists(load_lora_path):
                # not a local path: treat it as a hub repo id and download the weights file
                print(f"Grabbing lora from the hub: {load_lora_path}")
                new_lora_path = hf_hub_download(
                    load_lora_path,
                    filename="pytorch_lora_weights.safetensors"
                )
                # replace the path
                load_lora_path = new_lora_path
                # point the config at the downloaded file so later loads use it
                if self.model_config.inference_lora_path is not None:
                    self.model_config.inference_lora_path = new_lora_path
                if self.model_config.assistant_lora_path is not None:
                    self.model_config.assistant_lora_path = new_lora_path
            if self.model_config.assistant_lora_path is not None:
                # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
                # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
                # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
                # so we will merge in now and sample with -1 weight later
                self.invert_assistant_lora = True
                # trigger it to get merged in
                self.model_config.lora_path = self.model_config.assistant_lora_path
        if self.model_config.lora_path is not None:
            print("Fusing in LoRA")
            # need the pipe for peft
            pipe: FluxPipeline = FluxPipeline(
                scheduler=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                vae=None,
                transformer=transformer,
            )
            if self.low_vram:
                # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts
                # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu
                # we are going to separate it into the two transformer blocks one at a time
                lora_state_dict = load_file(self.model_config.lora_path)
                single_transformer_lora = {}
                single_block_key = "transformer.single_transformer_blocks."
                double_transformer_lora = {}
                double_block_key = "transformer.transformer_blocks."
                # split LoRA keys by block group; any other key cannot be fused piecewise
                for key, value in lora_state_dict.items():
                    if single_block_key in key:
                        single_transformer_lora[key] = value
                    elif double_block_key in key:
                        double_transformer_lora[key] = value
                    else:
                        raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
                # double blocks: move to the quantize device, fuse, unload the adapter, return to cpu
                transformer.transformer_blocks = transformer.transformer_blocks.to(
                    torch.device(self.quantize_device), dtype=dtype
                )
                pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
                pipe.fuse_lora()
                pipe.unload_lora_weights()
                transformer.transformer_blocks = transformer.transformer_blocks.to(
                    'cpu', dtype=dtype
                )
                # single blocks: same procedure
                transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
                    torch.device(self.quantize_device), dtype=dtype
                )
                pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
                pipe.fuse_lora()
                pipe.unload_lora_weights()
                transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
                    'cpu', dtype=dtype
                )
                # cleanup
                del single_transformer_lora
                del double_transformer_lora
                del lora_state_dict
                flush()
            else:
                # need the pipe to do this unfortunately for now
                # we have to fuse in the weights before quantizing
                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
                pipe.fuse_lora()
                # unfortunately, not an easier way with peft
                pipe.unload_lora_weights()
            flush()
        if self.model_config.quantize:
            quantization_type = qfloat8
            print("Quantizing transformer")
            quantize(transformer, weights=quantization_type)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)
        flush()
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
        print("Loading vae")
        vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()
        print("Loading t5")
        tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
        text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
                                                        torch_dtype=dtype)
        text_encoder_2.to(self.device_torch, dtype=dtype)
        flush()
        # the T5 encoder is always quantized here; unlike the SD3 path, this is not gated on
        # model_config.quantize
        print("Quantizing T5")
        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)
        flush()
        print("Loading clip")
        text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
        tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder.to(self.device_torch, dtype=dtype)
        print("making pipe")
        # text_encoder_2 and transformer are passed as None here and attached right after construction
        pipe: FluxPipeline = FluxPipeline(
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=None,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder_2 = text_encoder_2
        pipe.transformer = transformer
        print("preparing")
        text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        text_encoder[1].to(self.device_torch)
        text_encoder[1].requires_grad_(False)
        text_encoder[1].eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()
    else:
        # ---- default path: StableDiffusionPipeline (or custom_pipeline) with a single text encoder ----
        if self.custom_pipeline is not None:
            pipln = self.custom_pipeline
        else:
            pipln = StableDiffusionPipeline
        if self.model_config.text_encoder_bits < 16:
            # this is only supported for T5 models for now
            te_kwargs = {}
            # handle quantization of TE
            # NOTE(review): te_is_quantized is set but not read in this branch.
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            text_encoder = T5EncoderModel.from_pretrained(
                model_path,
                subfolder="text_encoder",
                torch_dtype=self.te_torch_dtype,
                **te_kwargs
            )
            # replace the to function with a no-op since it throws an error instead of a warning
            text_encoder.to = lambda *args, **kwargs: None
            load_args['text_encoder'] = text_encoder
        # see if path exists
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            # try to load with default diffusers
            pipe = pipln.from_pretrained(
                model_path,
                dtype=dtype,
                device=self.device_torch,
                load_safety_checker=False,
                requires_safety_checker=False,
                safety_checker=None,
                # variant="fp16",
                trust_remote_code=True,
                **load_args
            )
        else:
            pipe = pipln.from_single_file(
                model_path,
                dtype=dtype,
                device=self.device_torch,
                load_safety_checker=False,
                requires_safety_checker=False,
                torch_dtype=self.torch_dtype,
                safety_checker=None,
                trust_remote_code=True,
                **load_args
            )
        flush()
        pipe.register_to_config(requires_safety_checker=False)
        text_encoder = pipe.text_encoder
        text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        tokenizer = pipe.tokenizer
    # scheduler doesn't get set sometimes, so we set it here
    # NOTE(review): only the PixArt path backfills self.noise_scheduler when it is None. On every
    # other path, when __init__ received no scheduler, this sets pipe.scheduler to None (for Flux it
    # also replaces the FlowMatch scheduler built above). Confirm that callers always supply one.
    pipe.scheduler = self.noise_scheduler
    # add hacks to unet to help training
    # pipe.unet = prepare_unet_for_training(pipe.unet)
    if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
        # pixart, sd3, auraflow and flux use a transformer instead of a unet
        self.unet = pipe.transformer
    else:
        self.unet: 'UNet2DConditionModel' = pipe.unet
    self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
    self.vae.eval()
    self.vae.requires_grad_(False)
    # 2 ** (number of VAE blocks - 1); presumably the image-to-latent downsampling factor
    VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    self.vae_scale_factor = VAE_SCALE_FACTOR
    self.unet.to(self.device_torch, dtype=dtype)
    self.unet.requires_grad_(False)
    self.unet.eval()
    # load any loras we have (the Flux path fuses its LoRA earlier, before quantization)
    if self.model_config.lora_path is not None and not self.is_flux:
        pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
        pipe.fuse_lora()
        # unfortunately, not an easier way with peft
        pipe.unload_lora_weights()
    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.pipeline = pipe
    self.load_refiner()
    self.is_loaded = True
    if self.model_config.assistant_lora_path is not None:
        print("Loading assistant lora")
        self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
            self.model_config.assistant_lora_path, self)
        if self.invert_assistant_lora:
            # invert and disable during training
            self.assistant_lora.multiplier = -1.0
            self.assistant_lora.is_active = False
    if self.model_config.inference_lora_path is not None:
        print("Loading inference lora")
        self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
            self.model_config.inference_lora_path, self)
        # disable during training
        self.assistant_lora.is_active = False
    if self.is_pixart and self.vae_scale_factor == 16:
        # TODO make our own pipeline?
        # we generate an image 2x larger, so we need to copy the sizes from larger ones down
        # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
        # NOTE(review): these module-level dicts, imported from diffusers, are doubled in place. The
        # is_loaded guard is per instance, so a second instance meeting this condition doubles them again.
        for key in ASPECT_RATIO_256_BIN.keys():
            ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2]
        for key in ASPECT_RATIO_512_BIN.keys():
            ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2]
        for key in ASPECT_RATIO_1024_BIN.keys():
            ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2]
        for key in ASPECT_RATIO_2048_BIN.keys():
            ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2]
def te_train(self):
    """Switch the text encoder(s) to training mode.

    ``self.text_encoder`` is either a single module or a list of modules
    (one per encoder); ``.train()`` is called on each of them.
    """
    encoders = self.text_encoder if isinstance(self.text_encoder, list) else [self.text_encoder]
    for encoder in encoders:
        encoder.train()
def te_eval(self):
    """Switch the text encoder(s) to evaluation mode.

    ``self.text_encoder`` may hold one module or a list of modules; every
    module found gets ``.eval()`` called on it.
    """
    encoders = self.text_encoder
    if not isinstance(encoders, list):
        encoders = [encoders]
    for encoder in encoders:
        encoder.eval()
def load_refiner(self):
    """Load the optional SDXL refiner and keep only its UNet in ``self.refiner_unet``.

    Does nothing when ``self.model_config.refiner_name_or_path`` is None.
    A path that is a directory or does not exist locally (presumably a hub
    repo id) is loaded with ``from_pretrained``. An existing file is loaded
    with ``from_single_file`` using ``sd_xl_refiner.yaml`` from
    ``ORIG_CONFIGS_ROOT``. The rest of the loaded refiner pipeline is
    discarded and memory is flushed afterwards.
    """
    # for now, we are just going to rely on the TE from the base model
    # which is TE2 for SDXL and TE for SD (no refiner currently)
    # and completely ignore a TE that may or may not be packaged with the refiner
    if self.model_config.refiner_name_or_path is None:
        return

    # load the refiner model
    dtype = get_torch_dtype(self.dtype)
    model_path = self.model_config.refiner_name_or_path

    if not os.path.exists(model_path) or os.path.isdir(model_path):
        # TODO only load unet??
        # diffusers' from_pretrained takes the load dtype via `torch_dtype`;
        # a bare `dtype=` kwarg is not recognised and would leave the refiner
        # at the default dtype, unlike the single-file branch below.
        refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device=self.device_torch,
            # variant="fp16",
            use_safetensors=True,
        ).to(self.device_torch)
    else:
        refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
        refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
            model_path,
            dtype=dtype,
            device=self.device_torch,
            torch_dtype=self.torch_dtype,
            original_config_file=refiner_config_path,
        ).to(self.device_torch)

    # keep only the UNet; the refiner's text encoders / VAE are not used
    self.refiner_unet = refiner.unet
    del refiner
    flush()
def generate_images(
        self,
        image_configs: List[GenerateImageConfig],
        sampler=None,
        pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
):
    """Render, save and log one sample image per entry of ``image_configs``.

    Temporarily prepares the model for generation, then restores training
    state. Preparation covers assistant/inference LoRA toggles, the network
    (merged in when every config uses the same multiplier and the network
    supports it) and the 'generate' device-state preset. Restoration covers
    the RNG states, the device state, the network multiplier and merge, and
    the LoRA toggles.

    Args:
        image_configs: one ``GenerateImageConfig`` per image to render.
        sampler: optional sampler name. Names starting with ``"sample_"``
            take the k-diffusion code path. ``None`` keeps
            ``self.noise_scheduler``.
        pipeline: optional pre-built pipeline. When it is None, a pipeline
            matching the model type (XL / Flux / SD3 / PixArt / AuraFlow /
            SD) is constructed here.
    """
    merge_multiplier = 1.0
    flush()
    # `sampler` defaults to None, so evaluate the k-diffusion check once in a
    # None-safe way instead of calling `.startswith` on it directly.
    use_k_diffusion = sampler is not None and sampler.startswith("sample_")

    # if using assistant, unfuse it
    if self.model_config.assistant_lora_path is not None:
        print("Unloading assistant lora")
        if self.invert_assistant_lora:
            self.assistant_lora.is_active = True
            # move weights on to the device
            self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
        else:
            self.assistant_lora.is_active = False

    if self.model_config.inference_lora_path is not None:
        print("Loading inference lora")
        self.assistant_lora.is_active = True
        # move weights on to the device
        self.assistant_lora.force_to(self.device_torch, self.torch_dtype)

    if self.network is not None:
        self.network.eval()
        network = self.network
        # if every sample uses the same network weight we can merge the
        # network in, which drastically speeds up inference
        unique_network_weights = {x.network_multiplier for x in image_configs}
        if len(unique_network_weights) == 1 and self.network.can_merge_in:
            merge_multiplier = unique_network_weights.pop()
            network.merge_in(merge_weight=merge_multiplier)
    else:
        network = BlankNetwork()

    self.save_device_state()
    self.set_device_state_preset('generate')

    # save current seed state for training
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    if pipeline is None:
        noise_scheduler = self.noise_scheduler
        if sampler is not None:
            if use_k_diffusion:  # sample_dpmpp_2m
                # using ksampler
                noise_scheduler = get_sampler(
                    'lms', {
                        "prediction_type": self.prediction_type,
                    })
            else:
                noise_scheduler = get_sampler(
                    sampler,
                    {
                        "prediction_type": self.prediction_type,
                    },
                    'sd' if not self.is_pixart else 'pixart'
                )

            try:
                noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
            except Exception:
                # best effort: the scheduler may not support .to()
                pass

        if use_k_diffusion and self.is_xl:
            # using kdiffusion
            Pipe = StableDiffusionKDiffusionXLPipeline
        elif self.is_xl:
            Pipe = StableDiffusionXLPipeline
        elif self.is_v3:
            Pipe = StableDiffusion3Pipeline
        else:
            Pipe = StableDiffusionPipeline

        extra_args = {}
        if self.adapter is not None:
            if isinstance(self.adapter, T2IAdapter):
                if self.is_xl:
                    Pipe = StableDiffusionXLAdapterPipeline
                else:
                    Pipe = StableDiffusionAdapterPipeline
                extra_args['adapter'] = self.adapter
            elif isinstance(self.adapter, ControlNetModel):
                if self.is_xl:
                    Pipe = StableDiffusionXLControlNetPipeline
                else:
                    Pipe = StableDiffusionControlNetPipeline
                extra_args['controlnet'] = self.adapter
            elif isinstance(self.adapter, ReferenceAdapter):
                # pass the noise scheduler to the adapter
                self.adapter.noise_scheduler = noise_scheduler
            else:
                if self.is_xl:
                    extra_args['add_watermarker'] = False

        # TODO add clip skip
        if self.is_xl:
            pipeline = Pipe(
                vae=self.vae,
                unet=self.unet,
                text_encoder=self.text_encoder[0],
                text_encoder_2=self.text_encoder[1],
                tokenizer=self.tokenizer[0],
                tokenizer_2=self.tokenizer[1],
                scheduler=noise_scheduler,
                **extra_args
            ).to(self.device_torch)
            pipeline.watermark = None
        elif self.is_flux:
            if self.model_config.use_flux_cfg:
                pipeline = FluxWithCFGPipeline(
                    vae=self.vae,
                    transformer=self.unet,
                    text_encoder=self.text_encoder[0],
                    text_encoder_2=self.text_encoder[1],
                    tokenizer=self.tokenizer[0],
                    tokenizer_2=self.tokenizer[1],
                    scheduler=noise_scheduler,
                    **extra_args
                )
            else:
                pipeline = FluxPipeline(
                    vae=self.vae,
                    transformer=self.unet,
                    text_encoder=self.text_encoder[0],
                    text_encoder_2=self.text_encoder[1],
                    tokenizer=self.tokenizer[0],
                    tokenizer_2=self.tokenizer[1],
                    scheduler=noise_scheduler,
                    **extra_args
                )
            pipeline.watermark = None
        elif self.is_v3:
            pipeline = Pipe(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder[0],
                text_encoder_2=self.text_encoder[1],
                text_encoder_3=self.text_encoder[2],
                tokenizer=self.tokenizer[0],
                tokenizer_2=self.tokenizer[1],
                tokenizer_3=self.tokenizer[2],
                scheduler=noise_scheduler,
                **extra_args
            )
        elif self.is_pixart:
            pipeline = PixArtSigmaPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                **extra_args
            )
        elif self.is_auraflow:
            pipeline = AuraFlowPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                **extra_args
            )
        else:
            pipeline = Pipe(
                vae=self.vae,
                unet=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                **extra_args
            )
        flush()
        # disable progress bar
        pipeline.set_progress_bar_config(disable=True)

        if use_k_diffusion:
            pipeline.set_scheduler(sampler)

    refiner_pipeline = None
    if self.refiner_unet:
        # build refiner pipeline
        refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
            vae=pipeline.vae,
            unet=self.refiner_unet,
            text_encoder=None,
            text_encoder_2=pipeline.text_encoder_2,
            tokenizer=None,
            tokenizer_2=pipeline.tokenizer_2,
            scheduler=pipeline.scheduler,
            add_watermarker=False,
            requires_aesthetics_score=True,
        ).to(self.device_torch)
        # refiner_pipeline.register_to_config(requires_aesthetics_score=False)
        refiner_pipeline.watermark = None
        refiner_pipeline.set_progress_bar_config(disable=True)
        flush()

    start_multiplier = 1.0
    if self.network is not None:
        start_multiplier = self.network.multiplier

    # pipeline.to(self.device_torch)

    with network:
        with torch.no_grad():
            if self.network is not None:
                assert self.network.is_active

            for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
                gen_config = image_configs[i]

                extra = {}
                validation_image = None
                if self.adapter is not None and gen_config.adapter_image_path is not None:
                    validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
                    if isinstance(self.adapter, T2IAdapter):
                        # not sure why this is double??
                        validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
                        extra['image'] = validation_image
                        extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
                    if isinstance(self.adapter, ControlNetModel):
                        validation_image = validation_image.resize((gen_config.width, gen_config.height))
                        extra['image'] = validation_image
                        extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
                    if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                        transform = transforms.Compose([
                            transforms.ToTensor(),
                        ])
                        validation_image = transform(validation_image)
                    if isinstance(self.adapter, CustomAdapter):
                        # todo allow loading multiple
                        transform = transforms.Compose([
                            transforms.ToTensor(),
                        ])
                        validation_image = transform(validation_image)
                        self.adapter.num_images = 1
                    if isinstance(self.adapter, ReferenceAdapter):
                        # need -1 to 1
                        validation_image = transforms.ToTensor()(validation_image)
                        validation_image = validation_image * 2.0 - 1.0
                        validation_image = validation_image.unsqueeze(0)
                        self.adapter.set_reference_images(validation_image)

                if self.network is not None:
                    self.network.multiplier = gen_config.network_multiplier
                torch.manual_seed(gen_config.seed)
                torch.cuda.manual_seed(gen_config.seed)

                if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
                        and gen_config.adapter_image_path is not None:
                    # run through the adapter to saturate the embeds
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                    self.adapter(conditional_clip_embeds)

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                    # handle condition the prompts
                    gen_config.prompt = self.adapter.condition_prompt(
                        gen_config.prompt,
                        is_unconditional=False,
                    )
                    gen_config.prompt_2 = gen_config.prompt
                    gen_config.negative_prompt = self.adapter.condition_prompt(
                        gen_config.negative_prompt,
                        is_unconditional=True,
                    )
                    gen_config.negative_prompt_2 = gen_config.negative_prompt

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
                    self.adapter.trigger_pre_te(
                        tensors_0_1=validation_image,
                        is_training=False,
                        has_been_preprocessed=False,
                        quad_count=4
                    )

                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False

                # allow any manipulations to take place to embeddings
                gen_config.post_process_embeddings(
                    conditional_embeds,
                    unconditional_embeds,
                )

                if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                        and gen_config.adapter_image_path is not None:
                    # apply the image projection
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
                                                                                                True)
                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
                    unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

                if self.adapter is not None and isinstance(self.adapter,
                                                           CustomAdapter) and validation_image is not None:
                    conditional_embeds = self.adapter.condition_encoded_embeds(
                        tensors_0_1=validation_image,
                        prompt_embeds=conditional_embeds,
                        is_training=False,
                        has_been_preprocessed=False,
                        is_generating_samples=True,
                    )
                    unconditional_embeds = self.adapter.condition_encoded_embeds(
                        tensors_0_1=validation_image,
                        prompt_embeds=unconditional_embeds,
                        is_training=False,
                        has_been_preprocessed=False,
                        is_unconditional=True,
                        is_generating_samples=True,
                    )

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
                        gen_config.extra_values) > 0:
                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
                                                dtype=self.torch_dtype)
                    # apply extra values to the embeddings
                    self.adapter.add_extra_values(extra_values, is_unconditional=False)
                    self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)

                if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                    # if we have a refiner loaded, set the denoising end at the refiner start
                    extra['denoising_end'] = gen_config.refiner_start_at
                    extra['output_type'] = 'latent'
                    if not self.is_xl:
                        raise ValueError("Refiner is only supported for XL models")

                conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)

                if self.is_xl:
                    # fix guidance rescale for sdxl
                    # was trained on 0.7 (I believe)
                    grs = gen_config.guidance_rescale
                    # if grs is None or grs < 0.00001:
                    #     grs = 0.7
                    # grs = 0.0

                    if use_k_diffusion:
                        extra['use_karras_sigmas'] = True
                        # NOTE(review): extra_kwargs are merged only on the
                        # k-diffusion path here; confirm this is intended.
                        extra = {
                            **extra,
                            **gen_config.extra_kwargs,
                        }

                    img = pipeline(
                        # prompt=gen_config.prompt,
                        # prompt_2=gen_config.prompt_2,
                        prompt_embeds=conditional_embeds.text_embeds,
                        pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                        negative_prompt_embeds=unconditional_embeds.text_embeds,
                        negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                        # negative_prompt=gen_config.negative_prompt,
                        # negative_prompt_2=gen_config.negative_prompt_2,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        guidance_rescale=grs,
                        latents=gen_config.latents,
                        **extra
                    ).images[0]
                elif self.is_v3:
                    img = pipeline(
                        prompt_embeds=conditional_embeds.text_embeds,
                        pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                        negative_prompt_embeds=unconditional_embeds.text_embeds,
                        negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        latents=gen_config.latents,
                        **extra
                    ).images[0]
                elif self.is_flux:
                    if self.model_config.use_flux_cfg:
                        img = pipeline(
                            prompt_embeds=conditional_embeds.text_embeds,
                            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                            negative_prompt_embeds=unconditional_embeds.text_embeds,
                            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                    else:
                        img = pipeline(
                            prompt_embeds=conditional_embeds.text_embeds,
                            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                            # negative_prompt_embeds=unconditional_embeds.text_embeds,
                            # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                elif self.is_pixart:
                    # needs attention masks for some reason
                    img = pipeline(
                        prompt=None,
                        prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
                                                                                   dtype=self.unet.dtype),
                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
                                                                                   dtype=self.unet.dtype),
                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
                                                                                              dtype=self.unet.dtype),
                        negative_prompt=None,
                        # negative_prompt=gen_config.negative_prompt,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        latents=gen_config.latents,
                        **extra
                    ).images[0]
                elif self.is_auraflow:
                    pipeline: AuraFlowPipeline = pipeline

                    img = pipeline(
                        prompt=None,
                        prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
                                                                                   dtype=self.unet.dtype),
                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
                                                                                   dtype=self.unet.dtype),
                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
                                                                                              dtype=self.unet.dtype),
                        negative_prompt=None,
                        # negative_prompt=gen_config.negative_prompt,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        latents=gen_config.latents,
                        **extra
                    ).images[0]
                else:
                    img = pipeline(
                        # prompt=gen_config.prompt,
                        prompt_embeds=conditional_embeds.text_embeds,
                        negative_prompt_embeds=unconditional_embeds.text_embeds,
                        # negative_prompt=gen_config.negative_prompt,
                        height=gen_config.height,
                        width=gen_config.width,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        latents=gen_config.latents,
                        **extra
                    ).images[0]

                if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                    # slice off just the last 1280 on the last dim as refiner does not use first text encoder
                    # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
                    refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
                    refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
                    # run through refiner
                    img = refiner_pipeline(
                        # prompt=gen_config.prompt,
                        # prompt_2=gen_config.prompt_2,

                        # slice these as it does not use both text encoders
                        # height=gen_config.height,
                        # width=gen_config.width,
                        prompt_embeds=refiner_text_embeds,
                        pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                        negative_prompt_embeds=refiner_unconditional_text_embeds,
                        negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                        num_inference_steps=gen_config.num_inference_steps,
                        guidance_scale=gen_config.guidance_scale,
                        guidance_rescale=grs,
                        denoising_start=gen_config.refiner_start_at,
                        denoising_end=gen_config.num_inference_steps,
                        image=img.unsqueeze(0)
                    ).images[0]

                gen_config.save_image(img, i)
                gen_config.log_image(img, i)

            if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                self.adapter.clear_memory()

    # clear pipeline and cache to reduce vram usage
    del pipeline
    if refiner_pipeline is not None:
        del refiner_pipeline
    torch.cuda.empty_cache()

    # restore training state
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    self.restore_device_state()
    if self.network is not None:
        self.network.train()
        self.network.multiplier = start_multiplier

    self.unet.to(self.device_torch, dtype=self.torch_dtype)
    if network.is_merged_in:
        network.merge_out(merge_multiplier)
    # self.tokenizer.to(original_device_dict['tokenizer'])

    # re-fuse loras
    if self.model_config.assistant_lora_path is not None:
        print("Loading assistant lora")
        if self.invert_assistant_lora:
            self.assistant_lora.is_active = False
            # move weights off the device
            self.assistant_lora.force_to('cpu', self.torch_dtype)
        else:
            self.assistant_lora.is_active = True

    if self.model_config.inference_lora_path is not None:
        print("Unloading inference lora")
        self.assistant_lora.is_active = False
        # move weights off the device
        self.assistant_lora.force_to('cpu', self.torch_dtype)

    flush()
def get_latent_noise(
        self,
        height=None,
        width=None,
        pixel_height=None,
        pixel_width=None,
        batch_size=1,
        noise_offset=0.0,
):
    """Sample standard-normal latent noise for a batch.

    A latent size given directly (``height``/``width``) wins. Otherwise it is
    derived from the pixel size by floor-dividing by the VAE downscale
    factor, ``2 ** (len(block_out_channels) - 1)``.

    Args:
        height, width: latent-space size.
        pixel_height, pixel_width: image-space size, used only when the
            matching latent size is None.
        batch_size: number of noise samples.
        noise_offset: forwarded to ``apply_noise_offset``.

    Returns:
        Noise of shape (batch_size, channels, height, width) created on
        ``self.unet.device``, passed through ``apply_noise_offset``.
        ``channels`` is the unet's ``in_channels``, or 16 for Flux.

    Raises:
        ValueError: when neither size form is given for an axis.
    """
    downscale = 2 ** (len(self.vae.config['block_out_channels']) - 1)

    if height is None and pixel_height is None:
        raise ValueError("height or pixel_height must be specified")
    if width is None and pixel_width is None:
        raise ValueError("width or pixel_width must be specified")

    latent_h = pixel_height // downscale if height is None else height
    latent_w = pixel_width // downscale if width is None else width

    declared_channels = self.unet.config['in_channels']
    # the Flux transformer reports 64 in_channels; the noise here uses 16
    channels = 16 if self.is_flux else declared_channels

    shape = (batch_size, channels, latent_h, latent_w)
    return apply_noise_offset(torch.randn(shape, device=self.unet.device), noise_offset)
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
    """Build SDXL ``time_ids`` conditioning for a batch of latents.

    The pixel size is recovered from the latent size with the VAE downscale
    factor. It is used as both original and target size, with top-left crop
    coordinates (0, 0), i.e. no cropping.

    Args:
        latents: 4-D latent batch (batch, channels, height, width).
        requires_aesthetic_score: use the refiner layout, which replaces the
            target size with a fixed aesthetic score of 6.0
            (https://huggingface.co/papers/2307.01952).

    Returns:
        A (batch, 6) tensor, or (batch, 5) with the aesthetic score, on the
        latents' device and dtype. Returns None when the model is not XL.
    """
    scale = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    if not self.is_xl:
        return None

    batch_size, _channels, lat_h, lat_w = list(latents.shape)
    pixel_size = (lat_h * scale, lat_w * scale)
    no_crop = (0, 0)
    if requires_aesthetic_score:
        # refiner: simulate an aesthetic score instead of the target size
        tail = (6.0,)
    else:
        tail = pixel_size

    row = torch.tensor([list(pixel_size + no_crop + tail)])
    row = row.to(latents.device, dtype=latents.dtype)
    return torch.cat([row] * batch_size)
def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor
) -> torch.FloatTensor:
    """Noise a batch one sample at a time via ``self.noise_scheduler.add_noise``.

    Samples, noise and timesteps are split along dim 0 into single-item
    chunks. When exactly one timestep is given for a multi-sample batch, that
    timestep is reused for every sample. The per-sample results are
    concatenated back along dim 0.
    """
    batch = original_samples.shape[0]
    sample_parts = torch.chunk(original_samples, batch, dim=0)
    noise_parts = torch.chunk(noise, noise.shape[0], dim=0)
    step_parts = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    # broadcast a lone timestep across the whole batch
    if len(step_parts) == 1 and len(sample_parts) != 1:
        step_parts = [step_parts[0]] * len(sample_parts)

    noised = [
        self.noise_scheduler.add_noise(sample_parts[k], noise_parts[k], step_parts[k])
        for k in range(batch)
    ]
    return torch.cat(noised, dim=0)
| def predict_noise( | |
| self, | |
| latents: torch.Tensor, | |
| text_embeddings: Union[PromptEmbeds, None] = None, | |
| timestep: Union[int, torch.Tensor] = 1, | |
| guidance_scale=7.5, | |
| guidance_rescale=0, | |
| add_time_ids=None, | |
| conditional_embeddings: Union[PromptEmbeds, None] = None, | |
| unconditional_embeddings: Union[PromptEmbeds, None] = None, | |
| is_input_scaled=False, | |
| detach_unconditional=False, | |
| rescale_cfg=None, | |
| return_conditional_pred=False, | |
| guidance_embedding_scale=1.0, | |
| **kwargs, | |
| ): | |
| conditional_pred = None | |
| # get the embeddings | |
| if text_embeddings is None and conditional_embeddings is None: | |
| raise ValueError("Either text_embeddings or conditional_embeddings must be specified") | |
| if text_embeddings is None and unconditional_embeddings is not None: | |
| text_embeddings = concat_prompt_embeds([ | |
| unconditional_embeddings, # negative embedding | |
| conditional_embeddings, # positive embedding | |
| ]) | |
| elif text_embeddings is None and conditional_embeddings is not None: | |
| # not doing cfg | |
| text_embeddings = conditional_embeddings | |
| # CFG is comparing neg and positive, if we have concatenated embeddings | |
| # then we are doing it, otherwise we are not and takes half the time. | |
| do_classifier_free_guidance = True | |
| # check if batch size of embeddings matches batch size of latents | |
| if latents.shape[0] == text_embeddings.text_embeds.shape[0]: | |
| do_classifier_free_guidance = False | |
| elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: | |
| raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") | |
| latents = latents.to(self.device_torch) | |
| text_embeddings = text_embeddings.to(self.device_torch) | |
| timestep = timestep.to(self.device_torch) | |
| # if timestep is zero dim, unsqueeze it | |
| if len(timestep.shape) == 0: | |
| timestep = timestep.unsqueeze(0) | |
| # if we only have 1 timestep, we can just use the same timestep for all | |
| if timestep.shape[0] == 1 and latents.shape[0] > 1: | |
| # check if it is rank 1 or 2 | |
| if len(timestep.shape) == 1: | |
| timestep = timestep.repeat(latents.shape[0]) | |
| else: | |
| timestep = timestep.repeat(latents.shape[0], 0) | |
| # handle t2i adapters | |
| if 'down_intrablock_additional_residuals' in kwargs: | |
| # go through each item and concat if doing cfg and it doesnt have the same shape | |
| for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): | |
| if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: | |
| kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) | |
| # handle controlnet | |
| if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: | |
| # go through each item and concat if doing cfg and it doesnt have the same shape | |
| for idx, item in enumerate(kwargs['down_block_additional_residuals']): | |
| if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: | |
| kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) | |
| for idx, item in enumerate(kwargs['mid_block_additional_residual']): | |
| if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: | |
| kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) | |
| def scale_model_input(model_input, timestep_tensor): | |
| if is_input_scaled: | |
| return model_input | |
| mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) | |
| timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) | |
| out_chunks = [] | |
| # unsqueeze if timestep is zero dim | |
| for idx in range(model_input.shape[0]): | |
| # if scheduler has step_index | |
| if hasattr(self.noise_scheduler, '_step_index'): | |
| self.noise_scheduler._step_index = None | |
| out_chunks.append( | |
| self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) | |
| ) | |
| return torch.cat(out_chunks, dim=0) | |
| if self.is_xl: | |
| with torch.no_grad(): | |
| # 16, 6 for bs of 4 | |
| if add_time_ids is None: | |
| add_time_ids = self.get_time_ids_from_latents(latents) | |
| if do_classifier_free_guidance: | |
| # todo check this with larget batches | |
| add_time_ids = torch.cat([add_time_ids] * 2) | |
| if do_classifier_free_guidance: | |
| latent_model_input = torch.cat([latents] * 2) | |
| timestep = torch.cat([timestep] * 2) | |
| else: | |
| latent_model_input = latents | |
| latent_model_input = scale_model_input(latent_model_input, timestep) | |
| added_cond_kwargs = { | |
| # todo can we zero here the second text encoder? or match a blank string? | |
| "text_embeds": text_embeddings.pooled_embeds, | |
| "time_ids": add_time_ids, | |
| } | |
| if self.model_config.refiner_name_or_path is not None: | |
| # we have the refiner on the second half of everything. Do Both | |
| if do_classifier_free_guidance: | |
| raise ValueError("Refiner is not supported with classifier free guidance") | |
| if self.unet.training: | |
| input_chunks = torch.chunk(latent_model_input, 2, dim=0) | |
| timestep_chunks = torch.chunk(timestep, 2, dim=0) | |
| added_cond_kwargs_chunked = { | |
| "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), | |
| "time_ids": torch.chunk(add_time_ids, 2, dim=0), | |
| } | |
| text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) | |
| # predict the noise residual | |
| base_pred = self.unet( | |
| input_chunks[0], | |
| timestep_chunks[0], | |
| encoder_hidden_states=text_embeds_chunks[0], | |
| added_cond_kwargs={ | |
| "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], | |
| "time_ids": added_cond_kwargs_chunked['time_ids'][0], | |
| }, | |
| **kwargs, | |
| ).sample | |
| refiner_pred = self.refiner_unet( | |
| input_chunks[1], | |
| timestep_chunks[1], | |
| encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], | |
| # just use the first second text encoder | |
| added_cond_kwargs={ | |
| "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], | |
| # "time_ids": added_cond_kwargs_chunked['time_ids'][1], | |
| "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), | |
| }, | |
| **kwargs, | |
| ).sample | |
| noise_pred = torch.cat([base_pred, refiner_pred], dim=0) | |
| else: | |
| noise_pred = self.refiner_unet( | |
| latent_model_input, | |
| timestep, | |
| encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], | |
| # just use the first second text encoder | |
| added_cond_kwargs={ | |
| "text_embeds": text_embeddings.pooled_embeds, | |
| "time_ids": self.get_time_ids_from_latents(latent_model_input, | |
| requires_aesthetic_score=True), | |
| }, | |
| **kwargs, | |
| ).sample | |
| else: | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input.to(self.device_torch, self.torch_dtype), | |
| timestep, | |
| encoder_hidden_states=text_embeddings.text_embeds, | |
| added_cond_kwargs=added_cond_kwargs, | |
| **kwargs, | |
| ).sample | |
| conditional_pred = noise_pred | |
| if do_classifier_free_guidance: | |
| # perform guidance | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| conditional_pred = noise_pred_text | |
| noise_pred = noise_pred_uncond + guidance_scale * ( | |
| noise_pred_text - noise_pred_uncond | |
| ) | |
| # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 | |
| if guidance_rescale > 0.0: | |
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | |
| noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) | |
| else: | |
| with torch.no_grad(): | |
| if do_classifier_free_guidance: | |
| # if we are doing classifier free guidance, need to double up | |
| latent_model_input = torch.cat([latents] * 2, dim=0) | |
| timestep = torch.cat([timestep] * 2) | |
| else: | |
| latent_model_input = latents | |
| latent_model_input = scale_model_input(latent_model_input, timestep) | |
| # check if we need to concat timesteps | |
| if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: | |
| ts_bs = timestep.shape[0] | |
| if ts_bs != latent_model_input.shape[0]: | |
| if ts_bs == 1: | |
| timestep = torch.cat([timestep] * latent_model_input.shape[0]) | |
| elif ts_bs * 2 == latent_model_input.shape[0]: | |
| timestep = torch.cat([timestep] * 2, dim=0) | |
| else: | |
| raise ValueError( | |
| f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") | |
| # predict the noise residual | |
| if self.is_pixart: | |
| VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) | |
| batch_size, ch, h, w = list(latents.shape) | |
| height = h * VAE_SCALE_FACTOR | |
| width = w * VAE_SCALE_FACTOR | |
| if self.pipeline.transformer.config.sample_size == 256: | |
| aspect_ratio_bin = ASPECT_RATIO_2048_BIN | |
| elif self.pipeline.transformer.config.sample_size == 128: | |
| aspect_ratio_bin = ASPECT_RATIO_1024_BIN | |
| elif self.pipeline.transformer.config.sample_size == 64: | |
| aspect_ratio_bin = ASPECT_RATIO_512_BIN | |
| elif self.pipeline.transformer.config.sample_size == 32: | |
| aspect_ratio_bin = ASPECT_RATIO_256_BIN | |
| else: | |
| raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") | |
| orig_height, orig_width = height, width | |
| height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, | |
| ratios=aspect_ratio_bin) | |
| added_cond_kwargs = {"resolution": None, "aspect_ratio": None} | |
| if self.unet.config.sample_size == 128 or ( | |
| self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): | |
| resolution = torch.tensor([height, width]).repeat(batch_size, 1) | |
| aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) | |
| resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) | |
| aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) | |
| if do_classifier_free_guidance: | |
| resolution = torch.cat([resolution, resolution], dim=0) | |
| aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) | |
| added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} | |
| noise_pred = self.unet( | |
| latent_model_input.to(self.device_torch, self.torch_dtype), | |
| encoder_hidden_states=text_embeddings.text_embeds, | |
| encoder_attention_mask=text_embeddings.attention_mask, | |
| timestep=timestep, | |
| added_cond_kwargs=added_cond_kwargs, | |
| return_dict=False, | |
| **kwargs | |
| )[0] | |
| # learned sigma | |
| if self.unet.config.out_channels // 2 == self.unet.config.in_channels: | |
| noise_pred = noise_pred.chunk(2, dim=1)[0] | |
| else: | |
| noise_pred = noise_pred | |
| else: | |
| if self.unet.device != self.device_torch: | |
| self.unet.to(self.device_torch) | |
| if self.unet.dtype != self.torch_dtype: | |
| self.unet = self.unet.to(dtype=self.torch_dtype) | |
| if self.is_flux: | |
| with torch.no_grad(): | |
| bs, c, h, w = latent_model_input.shape | |
| latent_model_input_packed = rearrange( | |
| latent_model_input, | |
| "b c (h ph) (w pw) -> b (h w) (c ph pw)", | |
| ph=2, | |
| pw=2 | |
| ) | |
| img_ids = torch.zeros(h // 2, w // 2, 3) | |
| img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] | |
| img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] | |
| img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) | |
| txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) | |
| # # handle guidance | |
| if self.unet.config.guidance_embeds: | |
| if isinstance(guidance_scale, list): | |
| guidance = torch.tensor(guidance_scale, device=self.device_torch) | |
| else: | |
| guidance = torch.tensor([guidance_scale], device=self.device_torch) | |
| guidance = guidance.expand(latents.shape[0]) | |
| else: | |
| guidance = None | |
| cast_dtype = self.unet.dtype | |
| # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): | |
| noise_pred = self.unet( | |
| hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] | |
| # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) | |
| # todo make sure this doesnt change | |
| timestep=timestep / 1000, # timestep is 1000 scale | |
| encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), | |
| # [1, 512, 4096] | |
| pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] | |
| txt_ids=txt_ids, # [1, 512, 3] | |
| img_ids=img_ids, # [1, 4096, 3] | |
| guidance=guidance, | |
| return_dict=False, | |
| **kwargs, | |
| )[0] | |
| if isinstance(noise_pred, QTensor): | |
| noise_pred = noise_pred.dequantize() | |
| noise_pred = rearrange( | |
| noise_pred, | |
| "b (h w) (c ph pw) -> b c (h ph) (w pw)", | |
| h=latent_model_input.shape[2] // 2, | |
| w=latent_model_input.shape[3] // 2, | |
| ph=2, | |
| pw=2, | |
| c=latent_model_input.shape[1], | |
| ) | |
| elif self.is_v3: | |
| noise_pred = self.unet( | |
| hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), | |
| timestep=timestep, | |
| encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), | |
| pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), | |
| **kwargs, | |
| ).sample | |
| if isinstance(noise_pred, QTensor): | |
| noise_pred = noise_pred.dequantize() | |
| elif self.is_auraflow: | |
| # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image | |
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
| t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) | |
| t = t.to(self.device_torch, self.torch_dtype) | |
| noise_pred = self.unet( | |
| latent_model_input, | |
| encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), | |
| timestep=t, | |
| return_dict=False, | |
| )[0] | |
| else: | |
| noise_pred = self.unet( | |
| latent_model_input.to(self.device_torch, self.torch_dtype), | |
| timestep=timestep, | |
| encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), | |
| **kwargs, | |
| ).sample | |
| conditional_pred = noise_pred | |
| if do_classifier_free_guidance: | |
| # perform guidance | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) | |
| conditional_pred = noise_pred_text | |
| if detach_unconditional: | |
| noise_pred_uncond = noise_pred_uncond.detach() | |
| noise_pred = noise_pred_uncond + guidance_scale * ( | |
| noise_pred_text - noise_pred_uncond | |
| ) | |
| if rescale_cfg is not None and rescale_cfg != guidance_scale: | |
| with torch.no_grad(): | |
| # do cfg at the target rescale so we can match it | |
| target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( | |
| noise_pred_text - noise_pred_uncond | |
| ) | |
| target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() | |
| target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() | |
| pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach() | |
| pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() | |
| # match the mean and std | |
| noise_pred = (noise_pred - pred_mean) / pred_std | |
| noise_pred = (noise_pred * target_std) + target_mean | |
| # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 | |
| if guidance_rescale > 0.0: | |
| # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf | |
| noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) | |
| if return_conditional_pred: | |
| return noise_pred, conditional_pred | |
| return noise_pred | |
def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
    """Run one scheduler step separately for every sample in the batch.

    Some schedulers keep per-call state (e.g. ``_step_index``), so each batch
    item is stepped on its own with that state reset beforehand.

    Args:
        model_input: Model prediction with the batch on dim 0.
        latent_input: Current latents with the batch on dim 0.
        timestep_tensor: Either one timestep per sample, or a single timestep
            shared by the whole batch (a 0-dim tensor counts as one shared
            timestep).
        noise_scheduler: Scheduler to step with; defaults to ``self.noise_scheduler``.

    Returns:
        The per-sample ``step(...)`` results concatenated along dim 0.

    Raises:
        ValueError: If there are fewer timesteps than samples and more than one.
    """
    if noise_scheduler is None:
        noise_scheduler = self.noise_scheduler
    # sometimes the schedule tensors are on the wrong device; moving them is best effort
    if isinstance(noise_scheduler, (DDPMScheduler, LCMScheduler)):
        try:
            noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
            noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
            noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
        except Exception:
            # intentionally best effort: a failed move is ignored and stepping proceeds
            pass

    if timestep_tensor.dim() == 0:
        # a scalar timestep has no batch dim to chunk over; treat it as shared
        timestep_tensor = timestep_tensor.unsqueeze(0)

    batch_size = model_input.shape[0]
    mi_chunks = torch.chunk(model_input, batch_size, dim=0)
    latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
    timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
    if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
        # one shared timestep: reuse it for every sample
        timestep_chunks = timestep_chunks * len(mi_chunks)
    elif len(timestep_chunks) < batch_size:
        raise ValueError(
            f"Expected 1 or {batch_size} timesteps for a batch of {batch_size}, got {len(timestep_chunks)}")

    out_chunks = []
    for idx in range(batch_size):
        # Reset per-call scheduler state so each sample is stepped independently
        # (presumably lets the scheduler re-derive its index from the given timestep).
        if hasattr(noise_scheduler, '_step_index'):
            noise_scheduler._step_index = None
        if hasattr(noise_scheduler, 'is_scale_input_called'):
            # NOTE(review): presumably silences a "scale_model_input not called" check
            noise_scheduler.is_scale_input_called = True
        out_chunks.append(
            noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[0]
        )
    return torch.cat(out_chunks, dim=0)
| # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 | |
def diffuse_some_steps(
        self,
        latents: torch.FloatTensor,
        text_embeddings: PromptEmbeds,
        total_timesteps: int = 1000,
        start_timesteps=0,
        guidance_scale=1,
        add_time_ids=None,
        bleed_ratio: float = 0.5,
        bleed_latents: torch.FloatTensor = None,
        is_input_scaled=False,
        return_first_prediction=False,
        **kwargs,
):
    """Denoise ``latents`` over a slice of the scheduler's timesteps.

    Iterates ``self.noise_scheduler.timesteps[start_timesteps:total_timesteps]``.
    At each timestep it calls ``predict_noise`` (asking for the conditional
    prediction too) and then advances the latents with ``step_scheduler``.

    Args:
        latents: Starting latents.
        text_embeddings: Prompt embeddings forwarded to ``predict_noise``.
        total_timesteps: End index (exclusive) into the scheduler timesteps.
        start_timesteps: Start index into the scheduler timesteps.
        guidance_scale: Forwarded to ``predict_noise``.
        add_time_ids: Forwarded to ``predict_noise``.
        bleed_ratio: Blend weight used with ``bleed_latents``.
        bleed_latents: If given, after each step whose timestep differs from the
            scheduler's final timestep, the latents become
            ``latents * (1 - bleed_ratio) + bleed_latents * bleed_ratio``.
        is_input_scaled: Forwarded for the first step only; later steps pass False.
        return_first_prediction: Also return the first step's conditional prediction.
        **kwargs: Forwarded to ``predict_noise``.

    Returns:
        The final latents, or ``(latents, first_prediction)`` when
        ``return_first_prediction`` is set.
    """
    first_prediction = None
    timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
    for step_t in tqdm(timesteps_to_run, leave=False):
        # give the scalar timestep a leading batch dimension
        step_t = step_t.unsqueeze(0)
        noise_pred, conditional_pred = self.predict_noise(
            latents,
            text_embeddings,
            step_t,
            guidance_scale=guidance_scale,
            add_time_ids=add_time_ids,
            is_input_scaled=is_input_scaled,
            return_conditional_pred=True,
            **kwargs,
        )
        if return_first_prediction and first_prediction is None:
            first_prediction = conditional_pred

        # step_scheduler steps each sample separately (some schedulers, e.g. euler, need that)
        latents = self.step_scheduler(noise_pred, latents, step_t)

        # unless this is the schedule's final timestep, blend in the bleed latents
        if bleed_latents is not None and step_t != self.noise_scheduler.timesteps[-1]:
            latents = latents * (1 - bleed_ratio) + bleed_latents * bleed_ratio

        # any caller-side input scaling only applies to the first step
        is_input_scaled = False

    return (latents, first_prediction) if return_first_prediction else latents
def encode_prompt(
        self,
        prompt,
        prompt2=None,
        num_images_per_prompt=1,
        force_all=False,
        long_prompts=False,
        max_length=None,
        dropout_prob=0.0,
) -> PromptEmbeds:
    """Encode text prompt(s) into ``PromptEmbeds`` for the loaded architecture.

    The model-type flags are checked in this order: SDXL, SD3, PixArt,
    AuraFlow, Flux, a bare ``T5EncoderModel`` text encoder, then the default
    ``train_tools.encode_prompts`` path.

    Args:
        prompt: A prompt string or a list of prompt strings.
        prompt2: Optional second prompt(s); only used by the SDXL branch.
        num_images_per_prompt: Only forwarded by the SDXL and SD3 branches.
        force_all: Currently unused.
        long_prompts: When True, prompts are encoded without truncation.
        max_length: Forwarded by the SDXL, SD3 and default branches only. The
            other branches use fixed lengths: PixArt 300 (sigma) / 120, AuraFlow
            256, Flux 512, T5 256.
        dropout_prob: Forwarded to every encoding helper.

    Returns:
        The encoded prompts. Flux also sets ``pooled_embeds``; PixArt and
        AuraFlow also carry an ``attention_mask``.
    """
    prompts = prompt if isinstance(prompt, list) else [prompt]
    if prompt2 is None or isinstance(prompt2, list):
        prompts_2 = prompt2
    else:
        prompts_2 = [prompt2]
    truncate = not long_prompts

    if self.is_xl:
        # Both SDXL text encoders are always enabled.
        # TODO make per-encoder (random) enabling a config option
        return PromptEmbeds(
            train_tools.encode_prompts_xl(
                self.tokenizer,
                self.text_encoder,
                prompts,
                prompts_2,
                num_images_per_prompt=num_images_per_prompt,
                use_text_encoder_1=True,
                use_text_encoder_2=True,
                truncate=truncate,
                max_length=max_length,
                dropout_prob=dropout_prob,
            )
        )

    if self.is_v3:
        return PromptEmbeds(
            train_tools.encode_prompts_sd3(
                self.tokenizer,
                self.text_encoder,
                prompts,
                num_images_per_prompt=num_images_per_prompt,
                truncate=truncate,
                max_length=max_length,
                dropout_prob=dropout_prob,
                pipeline=self.pipeline,
            )
        )

    if self.is_pixart:
        pixart_max_length = 300 if self.model_config.is_pixart_sigma else 120
        embeds, attention_mask = train_tools.encode_prompts_pixart(
            self.tokenizer,
            self.text_encoder,
            prompts,
            truncate=truncate,
            max_length=pixart_max_length,
            dropout_prob=dropout_prob
        )
        return PromptEmbeds(embeds, attention_mask=attention_mask)

    if self.is_auraflow:
        embeds, attention_mask = train_tools.encode_prompts_auraflow(
            self.tokenizer,
            self.text_encoder,
            prompts,
            truncate=truncate,
            max_length=256,
            dropout_prob=dropout_prob
        )
        # the attention mask is attached but not used downstream
        return PromptEmbeds(embeds, attention_mask=attention_mask)

    if self.is_flux:
        # tokenizer and text_encoder are lists here
        text_embeds, pooled_embeds = train_tools.encode_prompts_flux(
            self.tokenizer,
            self.text_encoder,
            prompts,
            truncate=truncate,
            max_length=512,
            dropout_prob=dropout_prob,
            attn_mask=self.model_config.attn_masking
        )
        flux_embeds = PromptEmbeds(text_embeds)
        flux_embeds.pooled_embeds = pooled_embeds
        return flux_embeds

    if isinstance(self.text_encoder, T5EncoderModel):
        embeds, attention_mask = train_tools.encode_prompts_pixart(
            self.tokenizer,
            self.text_encoder,
            prompts,
            truncate=truncate,
            max_length=256,
            dropout_prob=dropout_prob
        )
        # Zero embeddings where the attention mask is 0 (presumably padding)
        # instead of passing the mask along. TODO: do we want the attn mask here?
        token_mask = attention_mask.unsqueeze(-1).expand(embeds.shape)
        embeds = embeds * token_mask.to(dtype=embeds.dtype, device=embeds.device)
        return PromptEmbeds(embeds)

    return PromptEmbeds(
        train_tools.encode_prompts(
            self.tokenizer,
            self.text_encoder,
            prompts,
            truncate=truncate,
            max_length=max_length,
            dropout_prob=dropout_prob
        )
    )
def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode a list of images into scaled VAE latents.

    Each image is indexed as (C, H, W). Images whose height or width is not a
    multiple of the VAE downscale factor (``2 ** (len(block_out_channels) - 1)``)
    are resized down to the nearest multiple. The images are then stacked into
    one batch (so they must end up with identical shapes), encoded, and mapped
    to ``scaling_factor * (latents - shift_factor)``, with shift 0 when the
    config has none.

    Args:
        image_list: Image tensors to encode.
        device: Device for the images, the VAE (if it is on the CPU) and the
            output; defaults to ``self.vae_device_torch``.
        dtype: Dtype for the images and the output; defaults to
            ``self.vae_torch_dtype``.

    Returns:
        The latent batch on ``device`` with ``dtype``.

    Raises:
        ValueError: If ``image_list`` is empty.
    """
    # len() instead of truthiness so a batched tensor is still accepted
    if len(image_list) == 0:
        raise ValueError("image_list must contain at least one image")
    if device is None:
        device = self.vae_device_torch
    if dtype is None:
        dtype = self.vae_torch_dtype

    # vae.device is a torch.device, which never compares equal to the string 'cpu';
    # compare the device *type* so a CPU-resident VAE is actually moved.
    if torch.device(self.vae.device).type == 'cpu':
        self.vae.to(device)
    self.vae.eval()
    self.vae.requires_grad_(False)

    images = [image.to(device, dtype=dtype) for image in image_list]
    scale = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    # resize images whose spatial size is not divisible by the VAE scale factor
    for idx, image in enumerate(images):
        height, width = image.shape[1], image.shape[2]
        if height % scale != 0 or width % scale != 0:
            images[idx] = Resize((height // scale * scale, width // scale * scale))(image)
    batch = torch.stack(images)

    if isinstance(self.vae, AutoencoderTiny):
        latents = self.vae.encode(batch, return_dict=False)[0]
    else:
        latents = self.vae.encode(batch).latent_dist.sample()

    shift = self.vae.config.get('shift_factor')
    if shift is None:
        shift = 0
    # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
    # z = self.scale_factor * (z - self.shift_factor)
    latents = self.vae.config['scaling_factor'] * (latents - shift)
    return latents.to(device, dtype=dtype)
def decode_latents(
        self,
        latents: torch.Tensor,
        device=None,
        dtype=None
):
    """Decode VAE latents back into images.

    Reverses the latent normalization ``scaling_factor * (z - shift_factor)``
    by computing ``z / scaling_factor + shift_factor``. A missing or ``None``
    shift factor (as in SD1/SDXL VAE configs) is treated as 0, before calling
    ``vae.decode``.

    Args:
        latents: Latents to decode.
        device: Device for the latents and output images; defaults to ``self.device``.
        dtype: Dtype for the latents and output images; defaults to ``self.torch_dtype``.

    Returns:
        The decoded images on ``device`` with ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype
    # vae.device is a torch.device, which never equals the string 'cpu'; compare its type
    if torch.device(self.vae.device).type == 'cpu':
        self.vae.to(self.device)
    latents = latents.to(device, dtype=dtype)
    shift = self.vae.config.get('shift_factor')
    if shift is None:
        shift = 0
    latents = (latents / self.vae.config['scaling_factor']) + shift
    images = self.vae.decode(latents).sample
    return images.to(device, dtype=dtype)
def encode_image_prompt_pairs(
        self,
        prompt_list: List[str],
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode every prompt in ``prompt_list`` into prompt embeddings.

    NOTE(review): ``image_list`` is accepted but not encoded yet (see TODO);
    the returned latent list is always empty.

    Args:
        prompt_list: Prompts to encode, one call to ``encode_prompt`` each.
        image_list: Images paired with the prompts (currently unused).
        device: Output device for the embeddings; defaults to ``self.device``.
        dtype: Output dtype for the embeddings; defaults to ``self.torch_dtype``.

    Returns:
        ``(embedding_list, latent_list)``: one embedding per prompt moved to
        ``device``/``dtype``, and an (empty) latent list.
    """
    # todo check image types and expand and rescale as needed, then encode image_list
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype
    # device and dtype are for outputs, so honor them for the embeddings
    embedding_list = [self.encode_prompt(prompt).to(device, dtype=dtype) for prompt in prompt_list]
    latent_list = []
    return embedding_list, latent_list
def get_weight_by_name(self, name):
    """Fetch one weight tensor by its prefixed name.

    Supported names:
        ``te{N}_<key>``: key ``<key>`` of text encoder N. N is a single digit
            and indexes ``self.text_encoder`` when it is a list; otherwise the
            single text encoder is used.
        ``unet_<key>``: key ``<key>`` of ``self.unet``.

    Raises:
        ValueError: If ``name`` has neither prefix.
    """
    if name.startswith('unet'):
        return self.unet.state_dict()[name[5:]]
    if not name.startswith('te'):
        raise ValueError(f"Unknown weight name: {name}")

    encoder_index = int(name[2])
    state_key = name[4:]
    if isinstance(self.text_encoder, list):
        encoder = self.text_encoder[encoder_index]
    else:
        encoder = self.text_encoder
    return encoder.state_dict()[state_key]
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
    """Delegate to ``toolkit.prompt_utils.inject_trigger_into_prompt``.

    All arguments are forwarded unchanged (options as keyword arguments);
    the helper's result is returned as-is.
    """
    trigger_options = {
        'trigger': trigger,
        'to_replace_list': to_replace_list,
        'add_if_not_present': add_if_not_present,
    }
    return inject_trigger_into_prompt(prompt, **trigger_options)
def state_dict(self, vae=True, text_encoder=True, unet=True):
    """Merge the selected components' state dicts into one prefixed OrderedDict.

    Keys get ``<prefix>_`` prepended unless they already start with the
    prefix. The prefixes are ``SD_PREFIX_VAE``, ``SD_PREFIX_TEXT_ENCODER``
    (plus the encoder index when there are several encoders) and
    ``SD_PREFIX_UNET``. For the VAE the "already prefixed" check uses the bare
    prefix, without the trailing underscore.
    """
    merged = OrderedDict()

    def _add_prefixed(source, new_prefix, existing_prefix):
        # keep keys that already carry existing_prefix, otherwise prepend new_prefix
        for key, value in source.items():
            merged[key if key.startswith(existing_prefix) else f"{new_prefix}{key}"] = value

    if vae:
        _add_prefixed(self.vae.state_dict(), f"{SD_PREFIX_VAE}_", f"{SD_PREFIX_VAE}")
    if text_encoder:
        if isinstance(self.text_encoder, list):
            for idx, encoder in enumerate(self.text_encoder):
                te_prefix = f"{SD_PREFIX_TEXT_ENCODER}{idx}_"
                _add_prefixed(encoder.state_dict(), te_prefix, te_prefix)
        else:
            te_prefix = f"{SD_PREFIX_TEXT_ENCODER}_"
            _add_prefixed(self.text_encoder.state_dict(), te_prefix, te_prefix)
    if unet:
        unet_prefix = f"{SD_PREFIX_UNET}_"
        _add_prefixed(self.unet.state_dict(), unet_prefix, unet_prefix)
    return merged
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
        OrderedDict[str, Parameter]:
    """Collect named parameters of the selected components.

    Names are prefixed with ``SD_PREFIX_VAE``, ``SD_PREFIX_TEXT_ENCODER`` (plus
    the encoder index for multiple encoders), ``SD_PREFIX_UNET`` and
    ``SD_PREFIX_REFINER_UNET``. On SDXL, a text encoder disabled in
    ``model_config`` (``use_text_encoder_1`` / ``use_text_encoder_2``) is
    skipped. For Flux, only ``transformer_blocks`` and
    ``single_transformer_blocks`` parameters are returned.

    Args:
        vae, text_encoder, unet, refiner: Which components to include.
        state_dict_keys: If True, replace the first ``.`` of each name with ``_``.

    Returns:
        An OrderedDict mapping names to parameters.
    """
    named_params: OrderedDict[str, Parameter] = OrderedDict()
    if vae:
        for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
            named_params[name] = param
    if text_encoder:
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
                if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
                    # text encoder 1 disabled: dont add these params
                    continue
                if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
                    # text encoder 2 disabled: dont add these params
                    continue
                for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
                    named_params[name] = param
        else:
            for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                named_params[name] = param
    if unet:
        if self.is_flux:
            # Only the double- and single-stream transformer blocks are trainable here.
            # Both ModuleLists number their children from 0, and in diffusers' Flux both
            # block types have attention params such as attn.to_q.weight. With a bare
            # unet prefix, single-block params would overwrite double-block ones, so each
            # list is prefixed with its attribute path (matching the non-flux naming).
            for block_attr in ('transformer_blocks', 'single_transformer_blocks'):
                block_list = getattr(self.unet, block_attr)
                for name, param in block_list.named_parameters(recurse=True,
                                                               prefix=f"{SD_PREFIX_UNET}.{block_attr}"):
                    named_params[name] = param
        else:
            for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                named_params[name] = param
    if refiner:
        for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
            named_params[name] = param

    # convert to state dict keys: just replace the first . with an _
    if state_dict_keys:
        named_params = OrderedDict((k.replace('.', '_', 1), v) for k, v in named_params.items())
    return named_params
def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
    """Save the trained refiner as a single LDM-format SDXL refiner checkpoint.

    Only the refiner UNet is trained, so the full refiner pipeline is reloaded
    from ``model_config._original_refiner_name_or_path``. Its UNet is replaced
    with ``self.refiner_unet``, and its VAE, second text encoder and UNet are
    converted with ``get_ldm_state_dict_from_diffusers(..., 'sdxl_refiner')``
    and written to ``output_file``. If ``self.config_file`` is set, it is copied
    next to the output as ``<output stem>.yaml``.

    Args:
        output_file: Destination ``.safetensors`` path.
        meta: Metadata stored in the safetensors file.
        save_dtype: Dtype of the saved weights.

    Raises:
        ValueError: If no refiner is configured.
    """
    if self.model_config.refiner_name_or_path is None:
        raise ValueError("Refiner must be specified to save it")
    refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
    dtype = get_torch_dtype(self.dtype)
    model_path = self.model_config._original_refiner_name_or_path
    if not os.path.exists(model_path) or os.path.isdir(model_path):
        # a diffusers directory, or (presumably) a hub repo id
        # TODO only load unet??
        refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            model_path,
            dtype=dtype,
            device='cpu',
            use_safetensors=True,
        )
    else:
        refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
            model_path,
            dtype=dtype,
            device='cpu',
            torch_dtype=self.torch_dtype,
            original_config_file=refiner_config_path,
        )
    # replace original unet with the trained one
    refiner.unet = self.refiner_unet
    flush()

    diffusers_state_dict = OrderedDict()
    for k, v in refiner.vae.state_dict().items():
        new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
        diffusers_state_dict[new_key] = v
    for k, v in refiner.text_encoder_2.state_dict().items():
        new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
        diffusers_state_dict[new_key] = v
    for k, v in refiner.unet.state_dict().items():
        new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
        diffusers_state_dict[new_key] = v

    converted_state_dict = get_ldm_state_dict_from_diffusers(
        diffusers_state_dict,
        'sdxl_refiner',
        device='cpu',
        dtype=save_dtype
    )

    # make sure the parent folder exists; os.makedirs('') raises for a bare filename
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_file(converted_state_dict, output_file, metadata=meta)

    if self.config_file is not None:
        output_path_no_ext = os.path.splitext(output_file)[0]
        output_config_path = f"{output_path_no_ext}.yaml"
        shutil.copyfile(self.config_file, output_config_path)
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
    """Save the model to ``output_file``.

    Output format:
        - A path ending in ``.safetensors`` is written as a single LDM-format
          checkpoint via ``save_ldm_model_from_diffusers``.
        - Any other path is treated as a diffusers directory. Flux saves only
          its transformer under ``<output_file>/transformer``; other models save
          the whole pipeline. In both cases ``meta`` is written to
          ``<output_file>/aitk_meta.yaml``.

    If ``self.config_file`` is set, it is copied next to the output as
    ``<output stem>.yaml``.

    Args:
        output_file: Target ``.safetensors`` file or diffusers directory.
        meta: Metadata to store alongside the model.
        save_dtype: Dtype used for the single-file checkpoint.
        logit_scale: Unused.
    """

    def _plain_yaml_data(value):
        # yaml.dump's default Dumper writes OrderedDicts/tuples with Python-specific
        # tags that yaml.safe_load rejects; convert to builtin dicts/lists first.
        if isinstance(value, dict):
            return {key: _plain_yaml_data(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_plain_yaml_data(item) for item in value]
        return value

    version_string = '1'
    if self.is_v2:
        version_string = '2'
    if self.is_xl:
        version_string = 'sdxl'
    if self.is_ssd:
        # overwrite sdxl because both will be true here
        version_string = 'ssd'
    if self.is_ssd and self.is_vega:
        version_string = 'vega'

    if not output_file.endswith('.safetensors'):
        # not a .safetensors file: save in diffusers format into a directory
        if self.is_flux:
            # only save the transformer
            self.unet.save_pretrained(
                save_directory=os.path.join(output_file, 'transformer'),
                safe_serialization=True,
            )
        else:
            self.pipeline.save_pretrained(
                save_directory=output_file,
                safe_serialization=True,
            )
        # save out meta config, keeping the original key order
        meta_path = os.path.join(output_file, 'aitk_meta.yaml')
        with open(meta_path, 'w', encoding='utf-8') as f:
            yaml.dump(_plain_yaml_data(meta), f, sort_keys=False)
    else:
        save_ldm_model_from_diffusers(
            sd=self,
            output_file=output_file,
            meta=meta,
            save_dtype=save_dtype,
            sd_version=version_string,
        )

    if self.config_file is not None:
        output_path_no_ext = os.path.splitext(output_file)[0]
        output_config_path = f"{output_path_no_ext}.yaml"
        shutil.copyfile(self.config_file, output_config_path)
def prepare_optimizer_params(
        self,
        unet=False,
        text_encoder=False,
        text_encoder_lr=None,
        unet_lr=None,
        refiner_lr=None,
        refiner=False,
        default_lr=1e-6,
):
    """Build optimizer parameter groups for the requested components.

    Only weights listed in the ``ldm_diffusers_keymap`` of the
    ``stable_diffusion_<version>.json`` keymap are trained, so that training
    matches the save mappings. Weights in ``DO_NOT_TRAIN_WEIGHTS`` and weights
    with ``requires_grad`` False are skipped. PixArt, AuraFlow and Flux UNets
    bypass the keymap and use every parameter that requires grad.

    Args:
        unet, text_encoder, refiner: Which components to build groups for.
        text_encoder_lr, unet_lr, refiner_lr: Per-group learning rates;
            ``default_lr`` is used when one is None.
        default_lr: Fallback learning rate.

    Returns:
        A list of ``{"params": [...], "lr": ...}`` dicts, in the order unet,
        text encoder, refiner.
    """
    # todo maybe only get locon ones?
    if self.is_v2:
        version = 'sd2'
    elif self.is_xl:
        version = 'sdxl'
    else:
        version = 'sd1'
    keymap_path = os.path.join(KEYMAPS_ROOT, f"stable_diffusion_{version}.json")
    with open(keymap_path, 'r') as keymap_file:
        mapped_diffusers_keys = list(json.load(keymap_file)['ldm_diffusers_keymap'].values())

    def _mapped_trainable(named_params, key_prefix=''):
        # params whose (prefixed) diffusers key is mapped, not blacklisted and trainable
        selected = []
        for diffusers_key in mapped_diffusers_keys:
            full_key = f"{key_prefix}{diffusers_key}"
            if (full_key in named_params
                    and full_key not in DO_NOT_TRAIN_WEIGHTS
                    and named_params[full_key].requires_grad):
                selected.append(named_params[full_key])
        return selected

    param_groups = []
    if unet:
        unet_named = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        if self.is_pixart or self.is_auraflow or self.is_flux:
            unet_params = [param for param in unet_named.values() if param.requires_grad]
        else:
            unet_params = _mapped_trainable(unet_named)
        param_groups.append({"params": unet_params, "lr": default_lr if unet_lr is None else unet_lr})
        print(f"Found {len(unet_params)} trainable parameter in unet")

    if text_encoder:
        te_named = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
        te_params = _mapped_trainable(te_named)
        param_groups.append({"params": te_params, "lr": default_lr if text_encoder_lr is None else text_encoder_lr})
        print(f"Found {len(te_params)} trainable parameter in text encoder")

    if refiner:
        refiner_named = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
                                              state_dict_keys=True)
        refiner_params = _mapped_trainable(refiner_named, key_prefix="refiner_")
        param_groups.append({"params": refiner_params, "lr": default_lr if refiner_lr is None else refiner_lr})
        print(f"Found {len(refiner_params)} trainable parameter in refiner")

    return param_groups
def save_device_state(self):
    """Snapshot the current mode/device/grad state of every sub-module.

    The snapshot is stored in ``self.device_state`` so that
    ``restore_device_state()`` can undo a temporary change. It starts from
    the top-level keys of ``empty_preset`` and records, per module:

    * ``'training'``      - the module's ``.training`` flag,
    * ``'device'``        - the module's ``.device``,
    * ``'requires_grad'`` - probed from one representative weight
      (not recorded for the VAE).

    ``'text_encoder'`` is a list of such dicts when ``self.text_encoder`` is a
    list. ``'adapter'`` and ``'refiner_unet'`` are only recorded when those
    modules are present.

    Raises:
        ValueError: if ``self.adapter`` is set but of an unrecognised type.
    """
    # requires_grad is sampled from a single weight rather than by scanning
    # every parameter; these model families are probed via proj_out, all
    # others via conv_in.
    if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
        unet_has_grad = self.unet.proj_out.weight.requires_grad
    else:
        unet_has_grad = self.unet.conv_in.weight.requires_grad

    self.device_state = {
        **empty_preset,
        'vae': {
            'training': self.vae.training,
            'device': self.vae.device,
        },
        'unet': {
            'training': self.unet.training,
            'device': self.unet.device,
            'requires_grad': unet_has_grad,
        },
    }

    if isinstance(self.text_encoder, list):
        te_states: List[dict] = []
        for encoder in self.text_encoder:
            # todo there has to be a better way to do this
            try:
                # encoders exposing text_model.final_layer_norm
                te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
            except AttributeError:
                # Otherwise fall back to the T5-style layout (same probe as the
                # T5EncoderModel / UMT5EncoderModel branch below). Only a
                # missing attribute is absorbed; a bare except previously also
                # swallowed KeyboardInterrupt/SystemExit and unrelated errors.
                te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
            te_states.append({
                'training': encoder.training,
                'device': encoder.device,
                'requires_grad': te_has_grad,
            })
        self.device_state['text_encoder'] = te_states
    else:
        if isinstance(self.text_encoder, (T5EncoderModel, UMT5EncoderModel)):
            te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
        else:
            te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
        self.device_state['text_encoder'] = {
            'training': self.text_encoder.training,
            'device': self.text_encoder.device,
            'requires_grad': te_has_grad,
        }

    if self.adapter is not None:
        if isinstance(self.adapter, IPAdapter):
            # the image projection's training flag stands in for requires_grad;
            # the device is taken from the unet
            requires_grad = self.adapter.image_proj_model.training
            adapter_device = self.unet.device
        elif isinstance(self.adapter, T2IAdapter):
            requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
            adapter_device = self.adapter.device
        elif isinstance(self.adapter, ControlNetModel):
            requires_grad = self.adapter.conv_in.training
            adapter_device = self.adapter.device
        elif isinstance(self.adapter, ClipVisionAdapter):
            requires_grad = self.adapter.embedder.training
            adapter_device = self.adapter.device
        elif isinstance(self.adapter, CustomAdapter):
            requires_grad = self.adapter.training
            adapter_device = self.adapter.device
        elif isinstance(self.adapter, ReferenceAdapter):
            # todo update this!!
            requires_grad = True
            adapter_device = self.adapter.device
        else:
            raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
        self.device_state['adapter'] = {
            'training': self.adapter.training,
            'device': adapter_device,
            'requires_grad': requires_grad,
        }

    if self.refiner_unet is not None:
        self.device_state['refiner_unet'] = {
            'training': self.refiner_unet.training,
            'device': self.refiner_unet.device,
            'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
        }
def restore_device_state(self):
    """Re-apply the snapshot held in ``self.device_state``, then clear it.

    ``save_device_state()`` is what fills ``self.device_state``. When no
    snapshot is pending (``None``), this is a no-op.
    """
    saved = self.device_state
    if saved is not None:
        self.set_device_state(saved)
        # consume the snapshot so a second restore does nothing
        self.device_state = None
def set_device_state(self, state):
    """Apply a device-state description to every sub-module.

    ``state`` is a dict keyed by module name: 'vae', 'unet', 'text_encoder',
    plus 'adapter' / 'refiner_unet' when those modules exist. Each entry
    holds 'training' (selects train() vs eval()), 'device' (passed to
    ``.to()``) and 'requires_grad' (ignored for the VAE). 'text_encoder' may
    be a single entry shared by all encoders, or a list with one entry per
    encoder. ``flush()`` is called once everything has been applied.
    """

    def _set_mode(module, training):
        # mirror the saved train/eval mode onto the module
        if training:
            module.train()
        else:
            module.eval()

    vae_state = state['vae']
    _set_mode(self.vae, vae_state['training'])
    self.vae.to(vae_state['device'])

    unet_state = state['unet']
    _set_mode(self.unet, unet_state['training'])
    self.unet.to(unet_state['device'])
    self.unet.requires_grad_(bool(unet_state['requires_grad']))

    encoders = self.text_encoder
    if isinstance(encoders, list):
        for idx, encoder in enumerate(encoders):
            te_state = state['text_encoder']
            # a single (non-list) entry applies to every encoder
            enc_state = te_state[idx] if isinstance(te_state, list) else te_state
            _set_mode(encoder, enc_state['training'])
            encoder.to(enc_state['device'])
            encoder.requires_grad_(enc_state['requires_grad'])
    else:
        te_state = state['text_encoder']
        _set_mode(encoders, te_state['training'])
        encoders.to(te_state['device'])
        encoders.requires_grad_(te_state['requires_grad'])

    if self.adapter is not None:
        adapter_state = state['adapter']
        self.adapter.to(adapter_state['device'])
        self.adapter.requires_grad_(adapter_state['requires_grad'])
        _set_mode(self.adapter, adapter_state['training'])

    if self.refiner_unet is not None:
        refiner_state = state['refiner_unet']
        self.refiner_unet.to(refiner_state['device'])
        self.refiner_unet.requires_grad_(refiner_state['requires_grad'])
        _set_mode(self.refiner_unet, refiner_state['training'])

    flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
    """Save the current device state, then switch to a named preset.

    The previous state is captured with ``save_device_state()`` first.
    Recognised presets:

    * ``'cache_latents'`` - only the VAE is placed on its device,
    * ``'cache_clip'``    - see the review note in the body,
    * ``'generate'``      - vae, unet, text encoder(s), adapter and refiner
      are placed on their devices.

    Any module not selected by the preset is sent to ``'cpu'``, including
    every module when the preset is unrecognised. No preset lists any module
    as training, so every entry is built with ``training`` and
    ``requires_grad`` set to False.
    """
    # keep a snapshot of the current layout before changing it
    self.save_device_state()

    training_modules = []  # none of the presets below trains anything
    active_modules = []
    # NOTE(review): 'clip' matches none of the module names checked below
    # (vae, unet, refiner_unet, text_encoder, adapter), so 'cache_clip'
    # currently sends every module to 'cpu'. Confirm which module was meant.
    for preset_name, modules in (
        ('cache_latents', ['vae']),
        ('cache_clip', ['clip']),
        ('generate', ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']),
    ):
        if device_state_preset == preset_name:
            active_modules = modules

    def _module_state(name, device_attr):
        # the device attribute is only read for active modules
        is_training = name in training_modules
        return {
            'training': is_training,
            'device': getattr(self, device_attr) if name in active_modules else 'cpu',
            'requires_grad': is_training,
        }

    state = copy.deepcopy(empty_preset)
    state['vae'] = _module_state('vae', 'vae_device_torch')
    state['unet'] = _module_state('unet', 'device_torch')
    if self.refiner_unet is not None:
        state['refiner_unet'] = _module_state('refiner_unet', 'device_torch')
    if isinstance(self.text_encoder, list):
        state['text_encoder'] = [
            _module_state('text_encoder', 'te_device_torch') for _ in self.text_encoder
        ]
    else:
        state['text_encoder'] = _module_state('text_encoder', 'te_device_torch')
    if self.adapter is not None:
        state['adapter'] = _module_state('adapter', 'device_torch')

    self.set_device_state(state)
def text_encoder_to(self, *args, **kwargs):
    """Call ``.to(*args, **kwargs)`` on every text encoder.

    Works whether ``self.text_encoder`` is a single module or a list of
    modules. The return values of ``.to()`` are discarded.
    """
    if isinstance(self.text_encoder, list):
        encoders = self.text_encoder
    else:
        encoders = [self.text_encoder]
    for enc in encoders:
        enc.to(*args, **kwargs)