#!/usr/bin/env python
# coding=utf-8
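"""LoRA fine-tuning of Stable Diffusion 1.4 (UNet attention layers only), driven by
`accelerate` and a custom `MyDataset`.

Example launch command (illustrative only -- the model id, data directory, and prompt
below are placeholders to adapt to your setup):

    accelerate launch train_lora.py \
        --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
        --train_data_dir ./data/layout_images \
        --output_dir ./sd14-lora \
        --resolution 512 \
        --train_batch_size 4 \
        --learning_rate 1e-4 \
        --rank 4 \
        --validation_prompt "a photo of a living room" \
        --report_to tensorboard
"""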
import argparse
import contextlib
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import (
    convert_state_dict_to_diffusers,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from transformers import CLIPTextModel, CLIPTokenizer
from MyDataset import MyDataset
from torch.utils.data import Subset
logger = get_logger(__name__)
if is_wandb_available():
import wandb
    # Expect WANDB_API_KEY to be set in the environment (e.g. via `wandb login`); do not hard-code credentials.
wandb.init(project="train_lora")
def log_validation(pipeline, args, accelerator, step, is_final_validation=False):
if args.validation_prompt is None:
return None
phase_name = "test" if is_final_validation else "validation"
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
generator = None
if args.seed is not None:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
all_logs = []
prompts = args.validation_prompt if isinstance(args.validation_prompt, list) else [args.validation_prompt]
autocast_ctx = contextlib.nullcontext() if torch.backends.mps.is_available() else torch.autocast(accelerator.device.type)
with autocast_ctx:
for prompt in prompts:
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(
prompt,
num_inference_steps=args.validation_num_inference_steps,
generator=generator,
).images[0]
)
all_logs.append((prompt, images))
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for prompt, images in all_logs:
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(f"{phase_name}/{prompt}", np_images, step, dataformats="NHWC")
elif tracker.name == "wandb":
payload = []
for prompt, images in all_logs:
for i, img in enumerate(images):
payload.append(wandb.Image(img, caption=f"{prompt} | {i}"))
tracker.log({phase_name: payload}, step=step)
gc.collect()
torch.cuda.empty_cache()
return all_logs
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="SD1.4 LoRA fine-tuning (UNet LoRA only).")
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, required=True)
parser.add_argument("--revision", type=str, default=None)
parser.add_argument("--variant", type=str, default=None)
parser.add_argument("--output_dir", type=str, default="sd14-lora")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--train_batch_size", type=int, default=4)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument("--max_train_steps", type=int, default=None)
parser.add_argument("--checkpointing_steps", type=int, default=1000)
parser.add_argument("--checkpoints_total_limit", type=int, default=None)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--gradient_checkpointing", action="store_true")
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--scale_lr", action="store_true", default=False)
parser.add_argument("--lr_scheduler", type=str, default="constant")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--lr_num_cycles", type=int, default=1)
parser.add_argument("--lr_power", type=float, default=1.0)
parser.add_argument("--use_8bit_adam", action="store_true")
parser.add_argument("--dataloader_num_workers", type=int, default=0)
parser.add_argument("--adam_beta1", type=float, default=0.9)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
parser.add_argument("--adam_epsilon", type=float, default=1e-8)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument("--allow_tf32", action="store_true")
parser.add_argument("--report_to", type=str, default="tensorboard")
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
parser.add_argument("--set_grads_to_none", action="store_true")
# dataset (kept for your MyDataset)
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--dataset_config_name", type=str, default=None)
parser.add_argument("--train_data_dir", type=str, default=None)
parser.add_argument("--train_data_prompt", type=str, default=None)
parser.add_argument("--image_column", type=str, default="image")
parser.add_argument("--caption_column", type=str, default="text")
parser.add_argument("--max_train_samples", type=int, default=None)
parser.add_argument("--proportion_empty_prompts", type=float, default=0.0)
# validation (prompt-only)
parser.add_argument("--validation_prompt", type=str, default=None, nargs="+")
parser.add_argument("--num_validation_images", type=int, default=4)
parser.add_argument("--validation_steps", type=int, default=100)
parser.add_argument("--validation_num_inference_steps", type=int, default=30)
parser.add_argument("--tracker_project_name", type=str, default="train_sd14_lora")
# LoRA
parser.add_argument("--rank", type=int, default=4)
parser.add_argument("--lora_alpha", type=int, default=None)
args = parser.parse_args(input_args) if input_args is not None else parser.parse_args()
if args.train_data_dir is None and args.dataset_name is None:
raise ValueError("Specify either --train_data_dir or --dataset_name (your MyDataset can use --train_data_dir).")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("--proportion_empty_prompts must be in [0, 1].")
if args.resolution % 8 != 0:
raise ValueError("--resolution must be divisible by 8.")
if args.lora_alpha is None:
args.lora_alpha = args.rank
return args
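# Drop samples that the dataset returned as None (assumed MyDataset behaviour for
# unreadable or invalid items) so one bad sample does not crash the whole batch.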
def collate_fn(examples):
examples = [ex for ex in examples if ex is not None]
if len(examples) == 0:
return None
pixel_values = torch.stack([ex["pixel_values"] for ex in examples]).to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([ex["input_ids"] for ex in examples])
return {"pixel_values": pixel_values, "input_ids": input_ids}
def main(args):
    if args.report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if torch.backends.mps.is_available():
accelerator.native_amp = False
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# Load scheduler / tokenizer / models (SD1.4)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision)
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant)
# Freeze base params
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# LoRA on UNet attention projections
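    # Only the attention projection layers (to_q / to_k / to_v / to_out.0) receive low-rank
    # adapters; with PEFT the effective LoRA scaling applied to their output is lora_alpha / rank.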
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)
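    # Under fp16 mixed precision, upcast only the trainable LoRA parameters to fp32 for
    # numerical stability; the frozen base weights stay in the half-precision weight_dtype.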
if args.mixed_precision == "fp16":
cast_training_params(unet, dtype=torch.float32)
if args.enable_xformers_memory_efficient_attention:
if not is_xformers_available():
raise ValueError("xformers is not available.")
import xformers # noqa: F401
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning("xFormers 0.0.16 may be unstable for training; upgrade to >=0.0.17 if issues occur.")
unet.enable_xformers_memory_efficient_attention()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# Optimizer (LoRA params only)
lora_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
if args.use_8bit_adam:
import bitsandbytes as bnb
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Dataset / Dataloader (use your MyDataset)
train_dataset = MyDataset(args, tokenizer)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
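    # The warmup/total step counts handed to get_scheduler are scaled by num_processes:
    # once prepared by accelerate, the LR scheduler is stepped on every process for each
    # optimizer step, so per-process counts are multiplied to preserve the intended schedule.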
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Save/load hooks: save only LoRA weights inside accelerator checkpoints
def save_model_hook(models, weights, output_dir):
if not accelerator.is_main_process:
return
unet_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(unet))):
                # Convert the PEFT state dict to the diffusers LoRA key format so the
                # checkpoint can be re-loaded via `StableDiffusionPipeline.lora_state_dict`
                # in the load hook below (matching the final save at the end of training).
                unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
else:
raise ValueError(f"Unexpected model class in save hook: {model.__class__}")
weights.pop()
StableDiffusionPipeline.save_lora_weights(
save_directory=output_dir,
unet_lora_layers=unet_lora_layers_to_save,
safe_serialization=True,
)
def load_model_hook(models, input_dir):
unet_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(unet))):
unet_ = model
else:
raise ValueError(f"Unexpected model class in load hook: {model.__class__}")
lora_state_dict, _ = StableDiffusionPipeline.lora_state_dict(input_dir)
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(f"Unexpected keys when loading LoRA: {unexpected_keys}")
if accelerator.mixed_precision == "fp16":
cast_training_params([unet_], dtype=torch.float32)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# Prepare
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# Recompute steps/epochs after prepare
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if accelerator.is_main_process:
tracker_config = dict(vars(args))
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Resume
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new run.")
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
disable=not accelerator.is_local_main_process,
)
# Train loop
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
if batch is None:
continue
with accelerator.accumulate(unet):
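                # Encode images into VAE latents and scale by the VAE's configured
                # scaling_factor (0.18215 for the SD 1.x autoencoder).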
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
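                # The regression target depends on the scheduler's prediction_type;
                # SD 1.4 ships with epsilon (noise) prediction by default.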
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
if args.checkpoints_total_limit is not None:
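                        # Rotate old checkpoints so that at most `checkpoints_total_limit`
                        # remain once the new checkpoint below has been written.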
checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
for ckpt in checkpoints[:num_to_remove]:
shutil.rmtree(os.path.join(args.output_dir, ckpt), ignore_errors=True)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
if accelerator.is_main_process and args.validation_prompt is not None and global_step % args.validation_steps == 0:
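                    # Build an inference pipeline that reuses the LoRA-adapted training UNet;
                    # every other component is reloaded from the base SD 1.4 checkpoint.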
pipe = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
safety_checker=None
)
log_validation(pipe, args, accelerator, global_step, is_final_validation=False)
del pipe
torch.cuda.empty_cache()
if global_step >= args.max_train_steps:
break
if global_step >= args.max_train_steps:
break
# Save final LoRA weights
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet_ = unwrap_model(unet).to(torch.float32)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
if args.validation_prompt is not None:
pipe = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
safety_checker=None
)
pipe.load_lora_weights(args.output_dir)
log_validation(pipe, args, accelerator, global_step, is_final_validation=True)
del pipe
torch.cuda.empty_cache()
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)