import argparse
import copy
import gc
import logging
import math
import os
import shutil
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    free_memory,
)
from diffusers.utils import (
    check_min_version,
    convert_state_dict_to_diffusers,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
    import wandb

check_min_version("0.35.0.dev0")

logger = get_logger(__name__)


def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    prompt: str = None,
    repo_folder=None,
    pipeline: DiffusionPipeline = None,
):
    img_str = ""
    if images is not None:
        for i, image in enumerate(images):
            image.save(os.path.join(repo_folder, f"image_{i}.png"))
            img_str += f"![img_{i}](./image_{i}.png)\n"

    model_description = f"""
# LoRA DreamBooth - {repo_id}

These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below.\n
{img_str}

LoRA for the text encoder was enabled: {train_text_encoder}.
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="creativeml-openrail-m",
        base_model=base_model,
        prompt=prompt,
        model_description=model_description,
        inference=True,
    )
    tags = ["text-to-image", "diffusers", "lora", "diffusers-training"]
    if isinstance(pipeline, StableDiffusionPipeline):
        tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
    else:
        tags.extend(["if", "if-diffusers"])
    model_card = populate_model_card(model_card, tags=tags)

    model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    torch_dtype,
    is_final_validation=False,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    # We train on the simplified objective, so if the scheduler was configured with a
    # learned variance, switch it to a fixed one for sampling.
    scheduler_args = {}

    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
    pipeline.set_progress_bar_config(disable=True)

    # Run inference.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

    if args.validation_images is None:
        images = []
        for _ in range(args.num_validation_images):
            with torch.amp.autocast(accelerator.device.type):
                image = pipeline(**pipeline_args, generator=generator).images[0]
            images.append(image)
    else:
        images = []
        for image in args.validation_images:
            image = Image.open(image)
            with torch.amp.autocast(accelerator.device.type):
                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
            images.append(image)

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    del pipeline
    free_memory()

    return images


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run DreamBooth validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` `args.num_validation_images` times."
        ),
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="lora-dreambooth-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images. All images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--pre_compute_text_embeddings",
        action="store_true",
        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
    )
    parser.add_argument(
        "--tokenizer_max_length",
        type=int,
        default=None,
        required=False,
        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
    )
    parser.add_argument(
        "--text_encoder_use_attention_mask",
        action="store_true",
        required=False,
        help="Whether to use attention mask for the text encoder",
    )
    parser.add_argument(
        "--validation_images",
        required=False,
        default=None,
        nargs="+",
        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input, such as when training image variation or superresolution.",
    )
    parser.add_argument(
        "--class_labels_conditioning",
        required=False,
        default=None,
        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help="The dimension of the LoRA update matrices.",
    )

    parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

    parser.add_argument(
        "--image_interpolation_mode",
        type=str,
        default="lanczos",
        choices=[
            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
        ],
        help="The image interpolation method to use for resizing images.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        # Warn about flags that have no effect without prior preservation.
        if args.class_data_dir is not None:
            warnings.warn("--class_data_dir is ignored without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("--class_prompt is ignored without --with_prior_preservation.")

    if args.train_text_encoder and args.pre_compute_text_embeddings:
        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")

    return args
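

# Example invocation (illustrative only; the model id, data directory, and prompts
# below are placeholders, not values mandated by this script):
#
#   accelerate launch train_dreambooth_lora.py \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --instance_data_dir="./dog" \
#     --output_dir="lora-dreambooth-model" \
#     --instance_prompt="a photo of sks dog" \
#     --resolution=512 \
#     --train_batch_size=1 \
#     --gradient_accumulation_steps=1 \
#     --learning_rate=1e-4 \
#     --lr_scheduler="constant" \
#     --lr_warmup_steps=0 \
#     --max_train_steps=500 \
#     --validation_prompt="a photo of sks dog in a bucket" \
#     --validation_epochs=50 \
#     --seed=0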


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        # Note: this relies on the module-level `args` parsed in `__main__`.
        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
        if interpolation is None:
            raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        instance_image = exif_transpose(instance_image)

        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(
                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if class_image.mode != "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example


def collate_fn(examples, with_prior_preservation=False):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch
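

# Shape sketch (illustrative): with prior preservation enabled and a per-device batch
# of B examples, instance and class halves are concatenated along the batch dimension,
# so downstream code sees a doubled batch:
#   pixel_values: [2 * B, 3, H, W]       (instance images first, class images second)
#   input_ids:    [2 * B, max_length]
# The torch.chunk(..., 2, dim=0) calls in the training loop undo exactly this split.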


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs
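

# Illustrative example (assuming a CLIP tokenizer with model_max_length == 77):
#
#   text_inputs = tokenize_prompt(tokenizer, "a photo of sks dog")
#   text_inputs.input_ids.shape       # torch.Size([1, 77]), padded to max length
#   text_inputs.attention_mask.shape  # torch.Size([1, 77]); 1 for real tokens, 0 for padding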


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
        return_dict=False,
    )
    prompt_embeds = prompt_embeds[0]

    return prompt_embeds
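

# Illustrative usage (hypothetical tensors): the encoder's last hidden state is returned,
# which the UNet consumes as cross-attention conditioning:
#
#   prompt_embeds = encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask)
#   prompt_embeds.shape  # e.g. torch.Size([1, 77, 768]) for SD 1.x's CLIP ViT-L/14 text encoder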


def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Gradient accumulation while training two models under `accelerator.accumulate`
    # is not currently supported in distributed training, so guard against it here.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled and not enough exist yet.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
                torch_dtype = torch.float16
            elif args.prior_generation_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
                variant=args.variant,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            free_memory()

    # Handle the repository creation.
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load the tokenizer.
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # Import the correct text encoder class.
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load the scheduler and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    try:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
        )
    except OSError:
        # Some models (e.g. IF) do not ship a VAE; fall back to training directly in pixel space.
        vae = None

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
    )

    # We only train the additional adapter LoRA layers; freeze everything else.
    if vae is not None:
        vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder, unet)
    # to half-precision, as these weights are only used for inference.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype.
    unet.to(accelerator.device, dtype=weight_dtype)
    if vae is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Add new LoRA weights to the UNet's attention layers.
    unet_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        lora_dropout=args.lora_dropout,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
    )
    unet.add_adapter(unet_lora_config)

    # Optionally add LoRA adapters to the text encoder's attention layers as well.
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            lora_alpha=args.rank,
            lora_dropout=args.lora_dropout,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        text_encoder.add_adapter(text_lora_config)
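
    # A sketch of the PEFT LoRA math behind these configs (not extra configuration):
    # each adapted weight becomes W + (lora_alpha / r) * B @ A, where A is r x d_in and
    # B is d_out x r. Since lora_alpha == r above, the scaling factor is 1, so `--rank`
    # trades capacity against adapter size without changing the scale of the update.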

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # Create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format.
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            # There are only two possibilities here: the model is either a UNet or a
            # text encoder; collect the LoRA layers of each.
            unet_lora_layers_to_save = None
            text_encoder_lora_layers_to_save = None

            for model in models:
                if isinstance(model, type(unwrap_model(unet))):
                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                elif isinstance(model, type(unwrap_model(text_encoder))):
                    text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(model)
                    )
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # Make sure to pop the weight so that the corresponding model is not saved again.
                weights.pop()

            StableDiffusionLoraLoaderMixin.save_lora_weights(
                output_dir,
                unet_lora_layers=unet_lora_layers_to_save,
                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        unet_ = None
        text_encoder_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(unet))):
                unet_ = model
            elif isinstance(model, type(unwrap_model(text_encoder))):
                text_encoder_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)

        unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")

        if incompatible_keys is not None:
            # Check only for unexpected keys.
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        if args.train_text_encoder:
            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)

        # Make sure the trainable params are in float32; the base models are kept
        # in `weight_dtype`.
        if args.mixed_precision == "fp16":
            models = [unet_]
            if args.train_text_encoder:
                models.append(text_encoder_)

            # Only upcast trainable parameters (LoRA) into fp32.
            cast_training_params(models, dtype=torch.float32)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [unet]
        if args.train_text_encoder:
            models.append(text_encoder)

        # Only upcast trainable parameters (LoRA) into fp32.
        cast_training_params(models, dtype=torch.float32)

    # Use 8-bit Adam for lower memory usage.
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation: only the LoRA parameters require gradients at this point.
    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
    if args.train_text_encoder:
        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    if args.pre_compute_text_embeddings:

        def compute_text_embeddings(prompt):
            with torch.no_grad():
                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
                prompt_embeds = encode_prompt(
                    text_encoder,
                    text_inputs.input_ids,
                    text_inputs.attention_mask,
                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                )

            return prompt_embeds

        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
        validation_prompt_negative_prompt_embeds = compute_text_embeddings("")

        if args.validation_prompt is not None:
            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
        else:
            validation_prompt_encoder_hidden_states = None

        if args.class_prompt is not None:
            pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
        else:
            pre_computed_class_prompt_encoder_hidden_states = None

        # The text encoder and tokenizer are no longer needed; free their memory.
        text_encoder = None
        tokenizer = None

        gc.collect()
        free_memory()
    else:
        pre_computed_encoder_hidden_states = None
        validation_prompt_encoder_hidden_states = None
        validation_prompt_negative_prompt_embeds = None
        pre_computed_class_prompt_encoder_hidden_states = None

    # Dataset and DataLoader creation.
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        encoder_hidden_states=pre_computed_encoder_hidden_states,
        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
        tokenizer_max_length=args.tokenizer_max_length,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps. The step counts are scaled
    # by num_processes because the prepared scheduler is stepped once per process.
    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
    if args.max_train_steps is None:
        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
        num_training_steps_for_scheduler = (
            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
        )
    else:
        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps_for_scheduler,
        num_training_steps=num_training_steps_for_scheduler,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps, as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
            logger.warning(
                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
                f"This inconsistency may result in the learning rate scheduler not functioning properly."
            )

    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        tracker_config.pop("validation_images")
        accelerator.init_trackers("dreambooth-lora", config=tracker_config)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save.
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint.
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

                if vae is not None:
                    # Convert images to latent space.
                    model_input = vae.encode(pixel_values).latent_dist.sample()
                    model_input = model_input * vae.config.scaling_factor
                else:
                    model_input = pixel_values

                # Sample noise that we'll add to the model input.
                noise = torch.randn_like(model_input)
                bsz, channels, height, width = model_input.shape
                # Sample a random timestep for each image.
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                )
                timesteps = timesteps.long()

                # Add noise to the model input according to the noise magnitude at each
                # timestep (this is the forward diffusion process).
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
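
                # A worked view of the forward diffusion step above (standard DDPM):
                #   noisy_model_input = sqrt(alpha_bar_t) * model_input + sqrt(1 - alpha_bar_t) * noise
                # where alpha_bar_t is the cumulative product of (1 - beta) up to timestep t,
                # so a larger t means a noisier input and a harder denoising target.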

                # Get the text embedding for conditioning.
                if args.pre_compute_text_embeddings:
                    encoder_hidden_states = batch["input_ids"]
                else:
                    encoder_hidden_states = encode_prompt(
                        text_encoder,
                        batch["input_ids"],
                        batch["attention_mask"],
                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                    )

                # Duplicate the noisy input along channels for UNets that expect twice as many input channels.
                if unwrap_model(unet).config.in_channels == channels * 2:
                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                if args.class_labels_conditioning == "timesteps":
                    class_labels = timesteps
                else:
                    class_labels = None

                # Predict the noise residual.
                model_pred = unet(
                    noisy_model_input,
                    timesteps,
                    encoder_hidden_states,
                    class_labels=class_labels,
                    return_dict=False,
                )[0]

                # If the model predicts variance as well, throw away the variance channels:
                # a 6-channel prediction (pixel-space models such as IF) holds 3 channels of
                # mean and 3 of predicted variance, and we train only on the predicted mean.
                if model_pred.shape[1] == 6:
                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)

                # Get the target for the loss depending on the prediction type.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the prediction and target into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute the instance loss.
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute the prior loss.
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
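
                # The chunking above mirrors collate_fn: instance examples occupy the first
                # half of the batch and class examples the second, so the combined objective is
                #   loss = MSE(pred, target) + prior_loss_weight * MSE(pred_prior, target_prior),
                # which keeps the model anchored to the class distribution while it learns the instance.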

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Check if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # Before saving state, check if this save would exceed `checkpoints_total_limit`.
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # Before we save the new checkpoint, we need to have at most `checkpoints_total_limit - 1` checkpoints.
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                # Build a validation pipeline around the current (still training) weights.
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=unwrap_model(unet),
                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )

                if args.pre_compute_text_embeddings:
                    pipeline_args = {
                        "prompt_embeds": validation_prompt_encoder_hidden_states,
                        "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
                    }
                else:
                    pipeline_args = {"prompt": args.validation_prompt}

                images = log_validation(
                    pipeline,
                    args,
                    accelerator,
                    pipeline_args,
                    epoch,
                    torch_dtype=weight_dtype,
                )

    # Save the LoRA layers.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unwrap_model(unet)
        unet = unet.to(torch.float32)

        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

        if args.train_text_encoder:
            text_encoder = unwrap_model(text_encoder)
            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
        else:
            text_encoder_state_dict = None

        StableDiffusionLoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            text_encoder_lora_layers=text_encoder_state_dict,
        )

        # Final inference: load the base pipeline again and attach the trained LoRA weights.
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
        )

        # Load the trained LoRA weights.
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

        # Run a final round of inference.
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
            images = log_validation(
                pipeline,
                args,
                accelerator,
                pipeline_args,
                epoch,
                is_final_validation=True,
                torch_dtype=weight_dtype,
            )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                train_text_encoder=args.train_text_encoder,
                prompt=args.instance_prompt,
                repo_folder=args.output_dir,
                pipeline=pipeline,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)