| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| import itertools |
| import logging |
| import math |
| import os |
| from pathlib import Path |
|
|
| import datasets |
| import diffusers |
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| import transformers |
| from accelerate import Accelerator |
| from accelerate.logging import get_logger |
| from accelerate.utils import set_seed |
| from diffusers import ( |
| AutoencoderKL, |
| DDIMScheduler, |
| ) |
| from diffusers.optimization import get_scheduler |
| from diffusers.utils import check_min_version |
| from diffusers.utils.import_utils import is_xformers_available |
| from packaging import version |
| from tqdm.auto import tqdm |
| from transformers import AutoTokenizer |
| from utils.args_loader import ( |
| import_model_class_from_model_name_or_path, |
| parse_args, |
| ) |
| from utils.dataset import collate_fn, log_validation, make_dataset |
| from utils.light_controlnet import ControlNetModel |
| from utils.tracemalloc import TorchTracemalloc, b2mb |
| from utils.unet_2d_condition import UNet2DConditionNewModel |
|
|
| from peft import BOFTConfig, get_peft_model |
| from peft.peft_model import PeftModel |
|
|
|
|
| |
# Abort early when the installed diffusers is older than this script requires.
check_min_version("0.16.0.dev0")

logger = get_logger(__name__)

# Module names handed to BOFTConfig(target_modules=...) when adapting the UNet.
UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "query", "value", "key"]

# Module names handed to BOFTConfig(target_modules=...) when adapting the text encoder.
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
|
|
|
|
@torch.no_grad()
def save_adaptor(accelerator, output_dir, nets_dict):
    """Write every network in ``nets_dict`` to ``output_dir/<key>``.

    Each entry is unwrapped from its accelerator wrapper first. PEFT models are
    written with ``save_pretrained`` using the state dict gathered by
    ``accelerator.get_state_dict``; any other model goes through
    ``accelerator.save_model``. Both paths request safetensors serialization.

    Args:
        accelerator: the ``Accelerator`` that may have wrapped the networks.
        output_dir: parent directory; one sub-directory is written per key.
        nets_dict: mapping of sub-directory name -> (possibly wrapped) network.
    """
    for name, wrapped_net in nets_dict.items():
        target_dir = os.path.join(output_dir, name)
        bare_net = accelerator.unwrap_model(wrapped_net)

        if not isinstance(bare_net, PeftModel):
            accelerator.save_model(bare_net, target_dir, safe_serialization=True)
            continue

        # The state dict comes from the *wrapped* model so the accelerator can gather it.
        bare_net.save_pretrained(
            target_dir,
            state_dict=accelerator.get_state_dict(wrapped_net),
            safe_serialization=True,
        )
|
|
|
|
def main(args):
    """Train a ControlNet together with the trainable UNet parameters using Accelerate.

    Loads the tokenizer, text encoder, VAE, UNet and ControlNet named in ``args``.
    When ``args.use_boft`` is set, the UNet is wrapped with BOFT adapters. When
    ``args.train_text_encoder`` and ``args.use_boft`` are both set, the text
    encoder is wrapped as well. The function then:

    * runs the noise-prediction training loop;
    * validates and checkpoints on the main process every
      ``validation_steps`` / ``checkpointing_steps`` optimizer steps;
    * prints GPU/CPU memory statistics after each epoch.

    Args:
        args: parsed command-line namespace (produced by ``parse_args``).

    Raises:
        ValueError: if xformers is requested but unavailable, if ControlNet or
            UNet are not float32 after setup, if the scheduler's prediction type
            is unknown, or if resuming from ``checkpoint-current`` without any
            numbered checkpoint to recover the step from.
        ImportError: if 8-bit Adam is requested without bitsandbytes.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )

    # Extra tracker init kwargs; only wandb is customised. Default to an empty dict so
    # init_trackers() below does not raise NameError for any other `report_to` value.
    wandb_init = {}
    if args.report_to == "wandb":
        wandb_init = {
            "wandb": {
                "name": args.wandb_run_name,
                "mode": "online",
            }
        }

    # Log the accelerator state once per process for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # Keep library logging quieter on every process except the local main one.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # ---- Tokenizer and models -------------------------------------------------------
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionNewModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    controlnet = ControlNetModel()
    if args.controlnet_model_name_or_path != "":
        logger.info(f"Loading existing controlnet weights from {args.controlnet_model_name_or_path}")
        controlnet.load_state_dict(torch.load(args.controlnet_model_name_or_path))

    if args.use_boft:
        config = BOFTConfig(
            boft_block_size=args.boft_block_size,
            boft_block_num=args.boft_block_num,
            boft_n_butterfly_factor=args.boft_n_butterfly_factor,
            target_modules=UNET_TARGET_MODULES,
            boft_dropout=args.boft_dropout,
            bias=args.boft_bias,
        )
        unet = get_peft_model(unet, config)
        unet.print_trainable_parameters()

    vae.requires_grad_(False)
    controlnet.requires_grad_(True)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    unet.train()
    controlnet.train()

    if args.train_text_encoder and args.use_boft:
        config = BOFTConfig(
            boft_block_size=args.boft_block_size,
            boft_block_num=args.boft_block_num,
            boft_n_butterfly_factor=args.boft_n_butterfly_factor,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            boft_dropout=args.boft_dropout,
            bias=args.boft_bias,
        )
        text_encoder = get_peft_model(text_encoder, config, adapter_name=args.wandb_run_name)
        text_encoder.print_trainable_parameters()

    if args.train_text_encoder:
        text_encoder.train()

    # dtype the models are cast to below, chosen from the accelerator's mixed-precision mode.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    controlnet.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
            if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft):
                text_encoder.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing()
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder and not (args.use_lora or args.use_boft or args.use_oft):
            text_encoder.gradient_checkpointing_enable()

    # NOTE(review): controlnet and unet were cast to weight_dtype above, so these float32
    # checks raise whenever mixed precision is fp16/bf16. Confirm whether the trainable
    # models are meant to stay in float32.
    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, copy of the weights should still be float32."
    )
    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
        raise ValueError(
            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
        )
    if accelerator.unwrap_model(unet).dtype != torch.float32:
        raise ValueError(
            f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # ---- Optimizer ------------------------------------------------------------------
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params_to_optimize = [param for param in controlnet.parameters() if param.requires_grad]
    params_to_optimize += [param for param in unet.parameters() if param.requires_grad]
    if args.train_text_encoder:
        params_to_optimize += [param for param in text_encoder.parameters() if param.requires_grad]

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # ---- Data -----------------------------------------------------------------------
    train_dataset = make_dataset(args, tokenizer, accelerator, "train")
    val_dataset = make_dataset(args, tokenizer, accelerator, "test")

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # ---- Schedule -------------------------------------------------------------------
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # NOTE(review): unet has trainable parameters in the optimizer but is not passed to
    # accelerator.prepare(), so it is not wrapped for distributed gradient sync. Confirm
    # this is intended for multi-process runs.
    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        controlnet, optimizer, train_dataloader, lr_scheduler
    )
    if args.train_text_encoder:
        text_encoder = accelerator.prepare(text_encoder)

    # prepare() may have changed the dataloader length (sharding), so recompute.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_init)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # ---- Resume ---------------------------------------------------------------------
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Prefer "checkpoint-current"; otherwise take the highest-numbered checkpoint.
            dirs = os.listdir(args.output_dir)
            if "checkpoint-current" in dirs:
                path = "checkpoint-current"
            else:
                dirs = [d for d in dirs if d.startswith("checkpoint")]
                dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
                path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            if path.split("-")[1] == "current":
                # "checkpoint-current" carries no step number. Recover it from the newest
                # numbered checkpoint in output_dir (directory names ending in "0"). This
                # works both for "latest" and for an explicitly given checkpoint-current path.
                numbered = [
                    d for d in os.listdir(args.output_dir) if d.startswith("checkpoint") and d.endswith("0")
                ]
                numbered = sorted(numbered, key=lambda x: int(x.split("-")[1]))
                if not numbered:
                    raise ValueError(
                        f"Cannot determine the global step for '{path}': no numbered checkpoint found in "
                        f"{args.output_dir}."
                    )
                global_step = int(numbered[-1].split("-")[1])
            else:
                global_step = int(path.split("-")[1])

            initial_global_step = global_step
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # ---- Training loop --------------------------------------------------------------
    for epoch in range(first_epoch, args.num_train_epochs):
        with TorchTracemalloc() as tracemalloc:
            for step, batch in enumerate(train_dataloader):
                # Skip already-processed batches of the resumed epoch.
                if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                    if step % args.gradient_accumulation_steps == 0:
                        progress_bar.update(1)
                        if args.report_to == "wandb":
                            accelerator.print(progress_bar)
                    continue

                # Use a single accumulate() context. Each accumulate() call advances
                # Accelerate's accumulation step counter, so nesting one context per model
                # would count every batch twice. unet is not wrapped by prepare(), so there
                # is no DDP no_sync to apply to it anyway.
                with accelerator.accumulate(controlnet):
                    # Encode images to scaled latents.
                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]

                    # One random diffusion timestep per sample.
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                    )
                    timesteps = timesteps.long()

                    # Forward diffusion: noise the latents according to each timestep.
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
                    guided_hint = controlnet(
                        controlnet_cond=controlnet_image,
                    )

                    model_pred = unet(
                        noisy_latents,
                        timesteps,
                        guided_hint=guided_hint,
                        encoder_hidden_states=encoder_hidden_states,
                    ).sample

                    # Regression target depends on the scheduler's prediction type.
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        # Clip every model whose parameters the optimizer updates, including
                        # the UNet (its trainable parameters are in params_to_optimize).
                        nets_to_clip = [controlnet, unet]
                        if args.train_text_encoder:
                            nets_to_clip.append(text_encoder)
                        params_to_clip = itertools.chain.from_iterable(net.parameters() for net in nets_to_clip)
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                # True only when an actual optimizer step happened on this batch.
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    if args.report_to == "wandb":
                        accelerator.print(progress_bar)
                    global_step += 1

                    step_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")

                    if accelerator.is_main_process:
                        if global_step % args.validation_steps == 0 or global_step == 1:
                            logger.info(f"Running validation... \n Generating {args.num_validation_images} images.")
                            logger.info("Running validation... ")

                            with torch.no_grad():
                                log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator)

                        if global_step % args.checkpointing_steps == 0:
                            save_adaptor(accelerator, step_save_path, {"controlnet": controlnet, "unet": unet})

                            if args.train_text_encoder:
                                save_adaptor(accelerator, step_save_path, {"text_encoder": text_encoder})

                            accelerator.save_state(step_save_path)

                            logger.info(f"Saved {global_step} state to {step_save_path}")
                            logger.info(f"Saved current state to {step_save_path}")

                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                if global_step >= args.max_train_steps:
                    break

        # TorchTracemalloc fills its statistics on exit, so report after the `with` block.
        accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
        accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
        accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )

        accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
        accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}")
        accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}")
        accelerator.print(
            f"CPU Total Peak Memory consumed during the train (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}"
        )

    accelerator.wait_for_everyone()
    accelerator.end_training()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse command-line options, then hand them to the trainer.
    args = parse_args()
    main(args=args)
|
|