import argparse
import gc
import logging
import math
import os
import shutil
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import PretrainedConfig

import diffusers
import powerpaint.datasets
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from powerpaint.datasets import ProbPickingDataset
from powerpaint.models import BrushNetModel, UNet2DConditionModel
from powerpaint.pipelines import StableDiffusionPowerPaintBrushNetPipeline


if is_wandb_available():
    import wandb


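# Will raise an error if the minimal version of diffusers is not installed.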
check_min_version("0.27.0.dev0")

logger = get_logger(__name__)


def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
    img_str = ""
    if image_logs is not None:
        img_str = "You can find some example images below.\n\n"
        for i, log in enumerate(image_logs):
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]

            validation_image.save(os.path.join(repo_folder, f"image_{i}.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
            img_str += "\n"

    model_description = f"""
# PowerPaint - {repo_id}

These are PowerPaint weights trained on {base_model} with a new type of conditioning.
{img_str}
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="creativeml-openrail-m",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )

    tags = [
        "stable-diffusion",
        "stable-diffusion-diffusers",
        "text-to-image",
        "diffusers",
        "PowerPaint",
        "diffusers-training",
    ]
    model_card = populate_model_card(model_card, tags=tags)

    model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(tokenizer, text_encoder, brushnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation...")

    pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
        ),
        tokenizer=tokenizer,
        text_encoder=accelerator.unwrap_model(text_encoder),
        brushnet=accelerator.unwrap_model(brushnet),
        safety_checker=None,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
        local_files_only=True,
    )
    pipe = pipe.to(accelerator.device)
    pipe.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipe.enable_xformers_memory_efficient_attention()

    image_logs = []
    for case in args.validation_data.cases:
        validation_prompts = case.prompt
        validation_image = Image.open(os.path.join(args.validation_data.data_root, case.image)).convert("RGB")
        validation_mask = Image.open(os.path.join(args.validation_data.data_root, case.mask))
        validation_mask = validation_mask.resize((validation_image.size[0], validation_image.size[1]), Image.NEAREST)
        validation_mask = validation_mask.convert("L")
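
        # Paste black into the hole region so that the conditioning image only shows unmasked context.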
        hole_value = (0, 0, 0)
        validation_image = Image.composite(
            Image.new("RGB", (validation_image.size[0], validation_image.size[1]), hole_value),
            validation_image,
            validation_mask,
        )

        image_grid = Image.new(
            "RGB",
            (validation_image.size[0] * (1 + len(validation_prompts)), validation_image.size[1]),
            (255, 255, 255),
        )
        image_grid.paste(validation_image, (0, 0))
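
        # Text-to-image cases use an all-white mask and an all-black conditioning image.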
        t2i_mask = Image.new("RGB", (validation_image.size[0], validation_image.size[1]), (255, 255, 255)).convert("L")
        t2i_image = Image.new("RGB", (validation_image.size[0], validation_image.size[1]), (0, 0, 0))
        for i, p in enumerate(validation_prompts):
            with torch.autocast(accelerator.device.type):
                image = pipe(
                    promptA=p.promptA,
                    promptB=p.promptB,
                    prompt=p.prompt,
                    negative_promptA=p.negative_promptA,
                    negative_promptB=p.negative_promptB,
                    negative_prompt=p.negative_prompt,
                    tradeoff=p.tradeoff,
                    image=validation_image if p.task != "t2i" else t2i_image,
                    mask=validation_mask if p.task != "t2i" else t2i_mask,
                    num_inference_steps=20,
                ).images[0]
            image_logs.append(image)
            image_grid.paste(image, (validation_image.size[0] * (i + 1), 0))
        image_grid.save(os.path.join(args.output_dir, f"{str(step).zfill(3)}_{os.path.basename(case.image)}"))
        gc.collect()
        torch.cuda.empty_cache()

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in image_logs])
            tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=f"{p.task}")
                        for image, p in zip(image_logs, args.validation_data.cases[0].prompt)
                    ]
                }
            )
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    del pipe
    gc.collect()
    torch.cuda.empty_cache()

    return image_logs


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(
        description="Simple example of a training script for PowerPaint based on the BrushNet architecture."
    )

    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a YAML configuration file.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--powerpaint_model_name_or_path",
        type=str,
        default=None,
        help="Path to a pretrained PowerPaint model or model identifier from huggingface.co/models."
        " If not specified, PowerPaint weights are initialized from the unet.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="runs/ppt2_bn",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path, if not the same as model_name.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=10000)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training"
            " via `--resume_from_checkpoint`. In the case that the checkpoint is better than the final trained model,"
            " the checkpoint can also be used for inference. Using a checkpoint for inference requires separate"
            " loading of the original pipeline and the individual checkpointed model components. See"
            " https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint"
            " for step-by-step instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help="Max number of checkpoints to store.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in the cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset; leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the PowerPaint conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this"
            " value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0."
        " More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running each case in `args.validation_data`"
            " through the pipeline and logging the resulting images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_powerpaint_brushnet",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers. For more information, see"
            " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

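    # Values loaded from the YAML config override CLI arguments of the same name.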
    if args.config is not None:
        config = OmegaConf.load(args.config)
        for k, v in config.items():
            args.__dict__[k] = v

    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the brushnet encoder."
        )

    return args


def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        to_save_config = OmegaConf.create(vars(args))
        OmegaConf.save(config=to_save_config, f=os.path.join(args.output_dir, "training_config.yaml"))

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

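    # For mixed precision training, cast the frozen models (vae, unet) to half precision;
    # the trainable brushnet and text encoder are kept in float32 below.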
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
        ),
        safety_checker=None,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
        local_files_only=True,
    )

    if args.powerpaint_model_name_or_path:
        logger.info("Loading existing powerpaint weights")
        pipe.brushnet = BrushNetModel.from_pretrained(args.powerpaint_model_name_or_path)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

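    # PowerPaint learns task-specific prompt tokens in a textual-inversion style: register the
    # placeholder tokens with the tokenizer and initialize their embeddings from the initializer tokens.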
    placeholder_tokens = [v.placeholder_tokens for k, v in args.task_prompt.items()]
    initializer_token = [v.initializer_token for k, v in args.task_prompt.items()]
    num_vectors_per_token = [v.num_vectors_per_token for k, v in args.task_prompt.items()]
    placeholder_token_ids = pipe.add_tokens(
        placeholder_tokens, initializer_token, num_vectors_per_token, initialize_parameters=True
    )

    vae, tokenizer, unet, noise_scheduler = pipe.vae, pipe.tokenizer, pipe.unet, pipe.scheduler
    text_encoder, brushnet = pipe.text_encoder.to(torch.float32), pipe.brushnet.to(torch.float32)
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

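    # Custom accelerate hooks so that `accelerator.save_state()` serializes the brushnet and the
    # text encoder in diffusers/transformers format, and `load_state()` restores them accordingly.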
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                for model in models:
                    sub_dir = "brushnet" if isinstance(model, type(unwrap_model(brushnet))) else "text_encoder"
                    model.save_pretrained(os.path.join(output_dir, sub_dir))

                    weights.pop()

        def load_model_hook(models, input_dir):
            while len(models) > 0:
                model = models.pop()

                if isinstance(model, type(unwrap_model(text_encoder))):
                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                    model.config = load_model.config
                else:
                    load_model = BrushNetModel.from_pretrained(input_dir, subfolder="brushnet")
                    model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        brushnet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            brushnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, a copy of the weights should still be float32."
    )

    if unwrap_model(brushnet).dtype != torch.float32:
        raise ValueError(f"BrushNet loaded as datatype {unwrap_model(brushnet).dtype}. {low_precision_error_string}")

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

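    # Freeze everything except the brushnet and the text encoder's input embeddings, so that
    # only the newly added task-prompt embeddings receive gradient updates in the text encoder.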
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    optimizer = optimizer_class(
        list(brushnet.parameters()) + list(text_encoder.get_input_embeddings().parameters()),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

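    # Each entry in args.train_data.datasets names a dataset class from powerpaint.datasets plus a
    # sampling probability; ProbPickingDataset mixes the datasets according to those probabilities.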
    datasets_list = []
    for d in args.train_data.datasets:
        dataset_class = getattr(powerpaint.datasets, d.dataset_class)
        dataset_ = dataset_class(train_transforms, pipe, args.task_prompt, **d)
        datasets_list.append({"dataset": dataset_, "prob": d.prob})

    train_dataset = ProbPickingDataset(datasets_list)

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    brushnet.train()
    text_encoder.train()

    brushnet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        brushnet, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

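    # Recalculate the total training steps, as the size of the dataloader may have changed
    # after accelerator.prepare().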
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        tracker_config = dict(vars(args))

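        # Trackers can only log simple config values, so drop entries of any other type.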
        pop_list = []
        for k, v in tracker_config.items():
            if not isinstance(v, (int, float, str, bool, torch.Tensor)):
                pop_list.append(k)
                logger.info(f"Removed {k} (type: {type(v)}) from tracker_config")
        for k in pop_list:
            tracker_config.pop(k)

        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info(f"***** Running training for {args.tracker_project_name} *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {int(args.max_train_steps)}")
    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            logger.info(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path), map_location="cpu")
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, int(args.max_train_steps)),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    image_logs = None

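    # Keep a copy of the original token embeddings so that, after each optimizer step, every row
    # except the newly added placeholder tokens can be restored.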
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

    for _ in range(first_epoch, args.num_train_epochs):
        train_loss = 0.0
        for batch in train_dataloader:
            with accelerator.accumulate(brushnet):
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor
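
                # Build the BrushNet conditioning: the mask downsampled to latent resolution
                # (1/8 of the pixel resolution), concatenated with the VAE latents of the masked
                # image, whose hole region is pushed to -1 (black) before encoding.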
                mask = torch.nn.functional.interpolate(batch["mask"], size=(args.resolution // 8, args.resolution // 8))
                mask_image = batch["pixel_values"] * (batch["mask"] < 0.5)
                mask_image = mask_image - batch["mask"]
                mask_image_latents = vae.encode(mask_image.to(weight_dtype)).latent_dist.sample()
                mask_image_latents = (mask_image_latents * vae.config.scaling_factor).to(weight_dtype)

                conditioning_latents = torch.concat([mask, mask_image_latents], 1)

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

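                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process).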
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states_unet = text_encoder(batch["input_ids"], return_dict=False)[0]

                encoder_hidden_statesA = text_encoder(batch["input_idsA"], return_dict=False)[0]
                encoder_hidden_statesB = text_encoder(batch["input_idsB"], return_dict=False)[0]
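
                # PowerPaint blends the embeddings of the two task prompts (A and B) with a
                # per-sample tradeoff weight; branch B is detached, so gradients flow only
                # through branch A.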
                tradeoff = batch["tradeoff"].unsqueeze(-1)
                encoder_hidden_states_brushnet = (
                    tradeoff[:, 0:1, :] * encoder_hidden_statesA + tradeoff[:, 1:, :] * encoder_hidden_statesB.detach()
                )

                down_block_res_samples, mid_block_res_sample, up_block_res_samples = brushnet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_brushnet.to(weight_dtype),
                    brushnet_cond=conditioning_latents,
                    return_dict=False,
                )

                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states_unet.detach().to(weight_dtype),
                    down_block_add_samples=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
                    mid_block_add_sample=mid_block_res_sample.to(dtype=weight_dtype),
                    up_block_add_samples=[sample.to(dtype=weight_dtype) for sample in up_block_res_samples],
                    return_dict=False,
                )[0]

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

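                # Min-SNR loss weighting (https://arxiv.org/abs/2303.09556): weight each sample
                # by min(SNR, snr_gamma). Epsilon-prediction divides by SNR; the velocity
                # objective requires adding one to the SNR before dividing.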
                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = list(brushnet.parameters()) + list(
                        accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
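
                # Textual-inversion bookkeeping: restore every embedding row except the newly
                # added placeholder tokens, so only the task-prompt embeddings are actually learned.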
                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False

                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
                        orig_embeds_params[index_no_updates]
                    )

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
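                        # Before saving a new checkpoint, remove old ones if this save would
                        # exceed `checkpoints_total_limit`.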
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if hasattr(args, "validation_data") and global_step % args.validation_steps == 0:
                        image_logs = log_validation(
                            tokenizer,
                            text_encoder,
                            brushnet,
                            args,
                            accelerator,
                            weight_dtype,
                            global_step,
                        )

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        brushnet = unwrap_model(brushnet)
        brushnet.save_pretrained(args.output_dir)

        image_logs = None
        if hasattr(args, "validation_data"):
            image_logs = log_validation(
                tokenizer,
                text_encoder,
                brushnet,
                args,
                accelerator,
                weight_dtype,
                global_step,
            )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                image_logs=image_logs,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)