#!/usr/bin/env python
# coding=utf-8
import argparse
import contextlib
import gc
import logging
import math
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import (
    convert_state_dict_to_diffusers,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from transformers import CLIPTextModel, CLIPTokenizer

from MyDataset import MyDataset

logger = get_logger(__name__)

if is_wandb_available():
    import wandb
    # Authenticate via `wandb login` or the WANDB_API_KEY environment variable;
    # never hardcode API keys in source. Run creation is handled later by
    # `accelerator.init_trackers`, so no module-level `wandb.init` is needed.


def log_validation(pipeline, args, accelerator, step, is_final_validation=False):
    if args.validation_prompt is None:
        return None

    phase_name = "test" if is_final_validation else "validation"
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    all_logs = []
    prompts = args.validation_prompt if isinstance(args.validation_prompt, list) else [args.validation_prompt]

    autocast_ctx = contextlib.nullcontext() if torch.backends.mps.is_available() else torch.autocast(accelerator.device.type)
    with autocast_ctx:
        for prompt in prompts:
            images = []
            for _ in range(args.num_validation_images):
                images.append(
                    pipeline(
                        prompt,
                        num_inference_steps=args.validation_num_inference_steps,
                        generator=generator,
                    ).images[0]
                )
            all_logs.append((prompt, images))

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for prompt, images in all_logs:
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images(f"{phase_name}/{prompt}", np_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            payload = []
            for prompt, images in all_logs:
                for i, img in enumerate(images):
                    payload.append(wandb.Image(img, caption=f"{prompt} | {i}"))
            tracker.log({phase_name: payload}, step=step)

    gc.collect()
    torch.cuda.empty_cache()
    return all_logs


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="SD1.4 LoRA fine-tuning (UNet LoRA only).")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default="sd14-lora")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--max_train_steps", type=int, default=None)
parser.add_argument("--checkpointing_steps", type=int, default=1000) parser.add_argument("--checkpoints_total_limit", type=int, default=None) parser.add_argument("--resume_from_checkpoint", type=str, default=None) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_checkpointing", action="store_true") parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--scale_lr", action="store_true", default=False) parser.add_argument("--lr_scheduler", type=str, default="constant") parser.add_argument("--lr_warmup_steps", type=int, default=500) parser.add_argument("--lr_num_cycles", type=int, default=1) parser.add_argument("--lr_power", type=float, default=1.0) parser.add_argument("--use_8bit_adam", action="store_true") parser.add_argument("--dataloader_num_workers", type=int, default=0) parser.add_argument("--adam_beta1", type=float, default=0.9) parser.add_argument("--adam_beta2", type=float, default=0.999) parser.add_argument("--adam_weight_decay", type=float, default=1e-2) parser.add_argument("--adam_epsilon", type=float, default=1e-8) parser.add_argument("--max_grad_norm", type=float, default=1.0) parser.add_argument("--logging_dir", type=str, default="logs") parser.add_argument("--allow_tf32", action="store_true") parser.add_argument("--report_to", type=str, default="tensorboard") parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") parser.add_argument("--set_grads_to_none", action="store_true") # dataset (kept for your MyDataset) parser.add_argument("--dataset_name", type=str, default=None) parser.add_argument("--dataset_config_name", type=str, default=None) parser.add_argument("--train_data_dir", type=str, default=None) parser.add_argument("--train_data_prompt", type=str, default=None) parser.add_argument("--image_column", type=str, default="image") parser.add_argument("--caption_column", type=str, default="text") parser.add_argument("--max_train_samples", type=int, default=None) parser.add_argument("--proportion_empty_prompts", type=float, default=0.0) # validation (prompt-only) parser.add_argument("--validation_prompt", type=str, default=None, nargs="+") parser.add_argument("--num_validation_images", type=int, default=4) parser.add_argument("--validation_steps", type=int, default=100) parser.add_argument("--validation_num_inference_steps", type=int, default=30) parser.add_argument("--tracker_project_name", type=str, default="train_sd14_lora") # LoRA parser.add_argument("--rank", type=int, default=4) parser.add_argument("--lora_alpha", type=int, default=None) args = parser.parse_args(input_args) if input_args is not None else parser.parse_args() if args.train_data_dir is None and args.dataset_name is None: raise ValueError("Specify either --train_data_dir or --dataset_name (your MyDataset can use --train_data_dir).") if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("--proportion_empty_prompts must be in [0, 1].") if args.resolution % 8 != 0: raise ValueError("--resolution must be divisible by 8.") if args.lora_alpha is None: args.lora_alpha = args.rank return args # def collate_fn(examples): # pixel_values = torch.stack([ex["pixel_values"] for ex in examples]).to(memory_format=torch.contiguous_format).float() # input_ids = torch.stack([ex["input_ids"] for ex in examples]) # return {"pixel_values": pixel_values, "input_ids": input_ids} 
def collate_fn(examples):
    # Drop samples the dataset returned as None (e.g. unreadable images).
    examples = [ex for ex in examples if ex is not None]
    if len(examples) == 0:
        return None
    pixel_values = torch.stack([ex["pixel_values"] for ex in examples]).to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([ex["input_ids"] for ex in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}


def main(args):
    if args.report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # Load scheduler / tokenizer / models (SD1.4)
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision)
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant)

    # Freeze base params
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # LoRA on UNet attention projections
    unet_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    unet.add_adapter(unet_lora_config)

    if args.mixed_precision == "fp16":
        # Only the trainable LoRA params are upcast; the frozen base stays fp16.
        cast_training_params(unet, dtype=torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available.")
        import xformers  # noqa: F401

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warning("xFormers 0.0.16 may be unstable for training; upgrade to >=0.0.17 if issues occur.")
        unet.enable_xformers_memory_efficient_attention()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
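
    # Illustrative sanity check (not part of the original script): after
    # `add_adapter`, only the LoRA adapter weights should require grad.
    if accelerator.is_local_main_process:
        trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
        total = sum(p.numel() for p in unet.parameters())
        logger.info(f"Trainable LoRA parameters: {trainable:,} / {total:,} total UNet parameters")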
    # Optimizer (LoRA params only)
    lora_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    if args.use_8bit_adam:
        import bitsandbytes as bnb

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset / Dataloader (use your MyDataset)
    train_dataset = MyDataset(args, tokenizer)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Save/load hooks: save only LoRA weights inside accelerator checkpoints
    def save_model_hook(models, weights, output_dir):
        if not accelerator.is_main_process:
            return
        unet_lora_layers_to_save = None
        for model in models:
            if isinstance(model, type(unwrap_model(unet))):
                # Convert the PEFT state dict to the diffusers format expected
                # by `save_lora_weights` (and by `lora_state_dict` on load).
                unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
            else:
                raise ValueError(f"Unexpected model class in save hook: {model.__class__}")
            # Pop the weight so accelerate does not save the full model again.
            weights.pop()
        StableDiffusionPipeline.save_lora_weights(
            save_directory=output_dir,
            unet_lora_layers=unet_lora_layers_to_save,
            safe_serialization=True,
        )

    def load_model_hook(models, input_dir):
        unet_ = None
        while len(models) > 0:
            model = models.pop()
            if isinstance(model, type(unwrap_model(unet))):
                unet_ = model
            else:
                raise ValueError(f"Unexpected model class in load hook: {model.__class__}")

        lora_state_dict, _ = StableDiffusionPipeline.lora_state_dict(input_dir)
        unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(f"Unexpected keys when loading LoRA: {unexpected_keys}")

        if accelerator.mixed_precision == "fp16":
            cast_training_params([unet_], dtype=torch.float32)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Prepare
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # Recompute steps/epochs after prepare
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Resume
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new run.")
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    # Train loop
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            if batch is None:
                continue
            with accelerator.accumulate(unet):
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
                    if args.checkpoints_total_limit is not None:
                        checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
                        if len(checkpoints) >= args.checkpoints_total_limit:
                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                            for ckpt in checkpoints[:num_to_remove]:
                                shutil.rmtree(os.path.join(args.output_dir, ckpt), ignore_errors=True)

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)

                if accelerator.is_main_process and args.validation_prompt is not None and global_step % args.validation_steps == 0:
                    pipe = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=unwrap_model(unet),
                        revision=args.revision,
                        variant=args.variant,
                        torch_dtype=weight_dtype,
                        safety_checker=None,
                    )
                    log_validation(pipe, args, accelerator, global_step, is_final_validation=False)
                    del pipe
                    torch.cuda.empty_cache()

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    # Save final LoRA weights
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet_ = unwrap_model(unet).to(torch.float32)
        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
        StableDiffusionPipeline.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            safe_serialization=True,
        )

        if args.validation_prompt is not None:
            pipe = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                revision=args.revision,
                variant=args.variant,
                torch_dtype=weight_dtype,
                safety_checker=None,
            )
            pipe.load_lora_weights(args.output_dir)
            log_validation(pipe, args, accelerator, global_step, is_final_validation=True)
            del pipe
            torch.cuda.empty_cache()

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
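
# Example usage (illustrative; the script filename and all paths/values below
# are assumptions, adjust them to your setup):
#
#   accelerate launch train_lora.py \
#       --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
#       --train_data_dir ./data \
#       --resolution 512 \
#       --train_batch_size 4 \
#       --rank 4 \
#       --validation_prompt "a photo of a cat" \
#       --output_dir sd14-lora
#
# Loading the trained LoRA for inference (sketch, using the weights written by
# `save_lora_weights` into --output_dir):
#
#   from diffusers import StableDiffusionPipeline
#   import torch
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, safety_checker=None
#   ).to("cuda")
#   pipe.load_lora_weights("sd14-lora")
#   image = pipe("a photo of a cat", num_inference_steps=30).images[0]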