# qwenillustrious / train / train_qwen_illustrious.py
# Uploaded by lsmpp via the upload-large-folder tool (commit d926b4c, verified).
#!/usr/bin/env python3
"""
训练 QwenIllustrious 模型 - 结合 Qwen 文本编码器和 SDXL UNet
"""
import argparse
import logging
import math
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import numpy as np
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version
from peft import LoraConfig, get_peft_model, TaskType
# 导入项目组件
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from arch import QwenTextEncoder, QwenEmbeddingAdapter
from arch.data_loader import QwenIllustriousDataset, collate_fn
from arch.model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler
# 检查最低版本
check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line arguments for training.

    Groups: model paths (Qwen encoder, SDXL UNet/VAE safetensors and
    configs), dataset/caching options, core training hyperparameters,
    LoRA settings, and logging/checkpointing controls.
    """
    parser = argparse.ArgumentParser(description="Train QwenIllustrious model")
    # Model arguments
    parser.add_argument(
        "--qwen_model_path",
        type=str,
        default="models/Qwen3-Embedding-0.6B",
        help="Path to Qwen text encoder model"
    )
    parser.add_argument(
        "--unet_model_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
        help="Path to UNet model"
    )
    parser.add_argument(
        "--unet_config_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
        help="Path to SDXL model config file"
    )
    parser.add_argument(
        "--vae_model_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        help="Path to VAE model (if different from SDXL)"
    )
    parser.add_argument(
        "--vae_config_path",
        type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        help="Path to VAE config file"
    )
    # Dataset arguments
    parser.add_argument(
        "--dataset_path",
        type=str,
        default='illustrious_generated',
        help="Path to illustrious_generated dataset"
    )
    # Caching is ON by default; this flag stores False into
    # args.precompute_embeddings when passed.
    parser.add_argument(
        "--no_precompute_embeddings",
        action="store_false",
        dest="precompute_embeddings",
        help="Disable precomputing and caching Qwen embeddings and VAE latents"
    )
    parser.set_defaults(precompute_embeddings=True)
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./illustrious_generated/cache",
        help="Directory to store precomputed embeddings"
    )
    # Training arguments
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./qwen_illustrious_output",
        help="Output directory for trained model"
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size for training"
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=10,
        help="Number of training epochs"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate"
    )
    # When None, main() derives this from epochs * steps-per-epoch.
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Maximum number of training steps"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of gradient accumulation steps"
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Enable gradient checkpointing"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help="Mixed precision training"
    )
    # LoRA arguments
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=64,
        help="LoRA rank for SDXL UNet cross attention"
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=64,
        help="LoRA alpha"
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="LoRA dropout"
    )
    # Other arguments
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed"
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Logging directory"
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help="Logging service (tensorboard, wandb, or all)"
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="qwen-illustrious",
        help="Wandb project name"
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help="Wandb run name (optional)"
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=25000,
        help="Save checkpoint every N steps"
    )
    # Accepts a path or the literal "latest" (resolved in main()).
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Path to checkpoint to resume from"
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help="Run validation every N epochs"
    )
    parser.add_argument(
        "--validation_prompts",
        type=str,
        nargs="+",
        default=[
            "A beautiful anime girl in a garden",
            "Two characters having a conversation",
            "A magical fantasy scene"
        ],
        help="Validation prompts"
    )
    return parser.parse_args()
def setup_models(args, accelerator):
    """Instantiate every model component and wire up LoRA training.

    Loads the frozen Qwen text encoder, the SDXL VAE/UNet from their
    safetensors checkpoints, the noise scheduler, and the embedding
    adapter, then injects LoRA layers into the UNet's cross-attention
    to_k/to_v projections. After this call only the adapter parameters
    and the LoRA weights require gradients.

    Returns:
        Tuple of (qwen_text_encoder, unet, vae, noise_scheduler, adapter).
    """
    logger.info("Loading models...")

    # Frozen Qwen text encoder; weights are never updated during training.
    qwen_text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        max_length=512,  # Default max length for Qwen
        freeze_encoder=True  # Freeze encoder parameters
    )

    # SDXL components extracted to standalone safetensors files.
    vae = load_vae_from_safetensors(args.vae_model_path, args.vae_config_path)
    unet = load_unet_from_safetensors(args.unet_model_path, args.unet_config_path)
    noise_scheduler = create_scheduler()

    # Adapter projecting Qwen embeddings into SDXL's conditioning space.
    adapter = QwenEmbeddingAdapter()

    logger.info(f"Setting up LoRA with rank={args.lora_rank}, alpha={args.lora_alpha}")

    # Target every cross-attention (attn2) key/value projection by name.
    target_modules = [
        module_name
        for module_name, _ in unet.named_modules()
        if "attn2" in module_name and ("to_k" in module_name or "to_v" in module_name)
    ]
    if not target_modules:
        logger.warning("No cross attention to_k/to_v modules found. Using default modules.")
        target_modules = ["to_k", "to_v"]
    logger.info(f"Applying LoRA to modules: {target_modules}")

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)

    # Freeze everything first, then selectively re-enable gradients for
    # the adapter and the injected LoRA weights.
    vae.requires_grad_(False)
    qwen_text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    adapter.requires_grad_(True)
    for param_name, param in unet.named_parameters():
        if "lora" in param_name:
            param.requires_grad_(True)

    # Surface the trainable footprint in the logs.
    total_params = sum(p.numel() for p in unet.parameters())
    trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    adapter_params = sum(p.numel() for p in adapter.parameters())
    logger.info(f"UNet total parameters: {total_params:,}")
    logger.info(f"UNet trainable parameters: {trainable_params:,}")
    logger.info(f"Adapter parameters: {adapter_params:,}")
    logger.info(f"Total trainable parameters: {trainable_params + adapter_params:,}")

    return qwen_text_encoder, unet, vae, noise_scheduler, adapter
def setup_dataset(args, qwen_text_encoder, vae, accelerator):
    """Construct the training dataset and its DataLoader.

    When ``args.precompute_embeddings`` is set, the dataset is handed
    the text encoder, VAE, and cache directory so Qwen embeddings and
    VAE latents are computed once up front; otherwise those hooks are
    disabled and encoding happens inside each training step.

    Returns:
        Tuple of (dataset, dataloader).
    """
    logger.info("Setting up dataset...")

    use_cache = args.precompute_embeddings
    dataset = QwenIllustriousDataset(
        dataset_path=args.dataset_path,
        qwen_text_encoder=qwen_text_encoder if use_cache else None,
        vae=vae if use_cache else None,
        cache_dir=args.cache_dir if use_cache else None,
        precompute_embeddings=use_cache
    )
    if use_cache:
        logger.info("Precomputing embeddings...")
        dataset.precompute_all(accelerator.device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn
    )
    return dataset, dataloader
def training_step(batch, unet, adapter, noise_scheduler, vae, qwen_text_encoder, accelerator, args):
    """Execute one denoising-diffusion training step and return the MSE loss.

    Obtains latents and adapted text conditioning (from the cache or by
    encoding on the fly), noises the latents at random timesteps, runs
    the UNet, and compares its prediction against the scheduler's target
    (noise for epsilon-parameterization, velocity for v-prediction).

    Raises:
        ValueError: if the scheduler's prediction_type is unrecognized.
    """
    device = accelerator.device

    if args.precompute_embeddings:
        # Cached path: latents and raw Qwen embeddings come precomputed
        # from the dataset; only the adapter runs here.
        latents = batch["latents"].to(device)
        raw_text_emb = batch["text_embeddings"].to(device)
        raw_pooled_emb = batch["pooled_embeddings"].to(device)
        text_embeddings = adapter.forward_text_embeddings(raw_text_emb)
        pooled_embeddings = adapter.forward_pooled_embeddings(raw_pooled_emb)
    else:
        # On-the-fly path: encode both images and prompts for this batch.
        images = batch["images"].to(device)
        prompts = batch["prompts"]
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
        with torch.no_grad():
            qwen_embeddings = qwen_text_encoder.encode_prompts(prompts, do_classifier_free_guidance=False)
        text_embeddings = adapter.forward_text_embeddings(qwen_embeddings[0])
        pooled_embeddings = adapter.forward_pooled_embeddings(qwen_embeddings[1])

    # Forward diffusion: add noise at uniformly sampled timesteps.
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,),
        device=latents.device, dtype=torch.long
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # SDXL conditioning: pooled text embeds plus (dummy, all-zero) time ids.
    added_cond_kwargs = {
        "text_embeds": pooled_embeddings,
        "time_ids": torch.zeros((bsz, 6), device=latents.device),
    }
    model_pred = unet(
        noisy_latents,
        timesteps,
        text_embeddings,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False
    )[0]

    # Pick the regression target according to the scheduler parameterization.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {prediction_type}")

    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
def validate_model(args, qwen_text_encoder, unet, adapter, vae, accelerator, epoch):
    """Validation placeholder: logs entry/exit and returns no metrics.

    Returns:
        An empty dict; real metrics should be added when validation
        generation is implemented.
    """
    logger.info(f"Running validation at epoch {epoch}")
    # TODO: generate images from args.validation_prompts and compute metrics.
    logger.info("Validation completed")
    return {}
def main():
    """Train the QwenIllustrious model (Qwen text encoder + SDXL UNet + LoRA).

    Orchestrates argument parsing, wandb/Accelerate setup, model and data
    preparation, the training loop (with gradient accumulation,
    checkpointing, and periodic validation), and the final export of the
    adapter, LoRA weights, fused UNet, and training config.
    """
    args = parse_args()

    # Initialize wandb if using it
    if args.report_to in ["wandb", "all"]:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Setup accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir,
        logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # Set seed for reproducibility
    if args.seed is not None:
        set_seed(args.seed)

    # Create output directories (main process only, to avoid races)
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        if args.precompute_embeddings:
            os.makedirs(args.cache_dir, exist_ok=True)

    # Setup models and dataset
    qwen_text_encoder, unet, vae, noise_scheduler, adapter = setup_models(args, accelerator)
    dataset, dataloader = setup_dataset(args, qwen_text_encoder, vae, accelerator)

    # Optimizer covers only the trainable parameters: adapter + LoRA
    trainable_params = list(adapter.parameters())
    for param in unet.parameters():
        if param.requires_grad:
            trainable_params.append(param)
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8,
    )

    # Calculate training steps
    num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Learning-rate scheduler
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.max_train_steps,
    )

    # Prepare trainable components for distributed / mixed-precision training
    unet, adapter, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        unet, adapter, optimizer, dataloader, lr_scheduler
    )

    # Place the frozen (inference-only) models.
    # BUGFIX: previously both were unconditionally moved to CPU, but when
    # embeddings are NOT precomputed, training_step runs vae.encode() and
    # qwen_text_encoder.encode_prompts() on batches moved to
    # accelerator.device, causing a device mismatch on GPU. They may only
    # stay on CPU when everything was precomputed and cached.
    if args.precompute_embeddings:
        qwen_text_encoder.to('cpu')
        vae.to('cpu')
    else:
        qwen_text_encoder.to(accelerator.device)
        vae.to(accelerator.device)

    # Initialize experiment tracking
    if accelerator.is_main_process:
        tracker_config = vars(args)
        accelerator.init_trackers(args.wandb_project, config=tracker_config)

    # Training loop
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batsize = {args.train_batch_size * accelerator.num_processes}" if False else f" Total train batch size = {args.train_batch_size * accelerator.num_processes}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Pick the checkpoint directory with the highest step number
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
        if path is None:
            accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        adapter.train()
        train_loss = 0.0
        epoch_loss = 0.0
        num_batches = 0
        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(unet):
                loss = training_step(batch, unet, adapter, noise_scheduler, vae, qwen_text_encoder, accelerator, args)
                # Backward pass
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # Average the per-process loss for logging
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps
                epoch_loss += avg_loss.item()
                num_batches += 1
                # Log metrics to both accelerator and wandb
                log_dict = {
                    "train/step_loss": avg_loss.item(),
                    "train/learning_rate": lr_scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                    "train/global_step": global_step
                }
                accelerator.log(log_dict, step=global_step)
                if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
                    wandb.log(log_dict, step=global_step)
                train_loss = 0.0
                # Save checkpoint
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            if global_step >= args.max_train_steps:
                break

        # Log epoch statistics
        if num_batches > 0 and accelerator.is_main_process:
            avg_epoch_loss = epoch_loss / num_batches
            epoch_log_dict = {
                "train/epoch_loss": avg_epoch_loss,
                "train/epoch_num": epoch
            }
            accelerator.log(epoch_log_dict, step=global_step)
            if args.report_to in ["wandb", "all"]:
                wandb.log(epoch_log_dict, step=global_step)
            logger.info(f"Epoch {epoch} - Average Loss: {avg_epoch_loss:.4f}")

        # Validation
        if epoch % args.validation_epochs == 0:
            validation_metrics = validate_model(
                args, qwen_text_encoder, unet, adapter, vae, accelerator, epoch
            )
            if validation_metrics:
                accelerator.log(validation_metrics, step=global_step)
                if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
                    wandb.log(validation_metrics, step=global_step)

        # BUGFIX: also stop the epoch loop once max_train_steps is reached;
        # previously only the inner loop broke, so every subsequent epoch
        # still ran extra optimizer steps before hitting the check again.
        if global_step >= args.max_train_steps:
            break

    # Save final model
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        from safetensors.torch import save_file
        from peft import get_peft_model_state_dict
        logger.info("Saving trained models...")

        # Unwrap any accelerate/DDP wrapper before touching PEFT internals
        # (state dicts, peft_config, merging) — wrapped models would fail.
        unwrapped_unet = accelerator.unwrap_model(unet)
        unwrapped_adapter = accelerator.unwrap_model(adapter)

        # Save adapter in safetensor format
        adapter_save_path = os.path.join(args.output_dir, "adapter")
        os.makedirs(adapter_save_path, exist_ok=True)
        adapter_state_dict = unwrapped_adapter.state_dict()
        save_file(adapter_state_dict, os.path.join(adapter_save_path, "adapter.safetensors"))
        logger.info(f"Adapter saved to {adapter_save_path}/adapter.safetensors")

        # Save LoRA weights only (in safetensor format)
        lora_save_path = os.path.join(args.output_dir, "lora_weights")
        os.makedirs(lora_save_path, exist_ok=True)
        lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
        save_file(lora_state_dict, os.path.join(lora_save_path, "lora_weights.safetensors"))
        logger.info(f"LoRA weights saved to {lora_save_path}/lora_weights.safetensors")

        # Save LoRA config
        unwrapped_unet.peft_config['default'].save_pretrained(lora_save_path)
        logger.info(f"LoRA config saved to {lora_save_path}/adapter_config.json")

        # Save full UNet with fused LoRA weights
        logger.info("Fusing LoRA weights into UNet...")
        unet_fused_save_path = os.path.join(args.output_dir, "unet_fused")
        os.makedirs(unet_fused_save_path, exist_ok=True)
        # BUGFIX: the trained unet is already a PeftModel; the previous code
        # re-wrapped it with PeftModel.from_pretrained, nesting a second
        # adapter around the first instead of merging the trained one.
        # Merge the in-memory LoRA weights directly into the base model.
        unet_merged = unwrapped_unet.merge_and_unload()
        unet_merged.save_pretrained(
            unet_fused_save_path,
            safe_serialization=True
        )
        logger.info(f"Fused UNet saved to {unet_fused_save_path}")

        # Save training config
        import json
        config_save_path = os.path.join(args.output_dir, "training_config.json")
        training_config = {
            "qwen_model_path": args.qwen_model_path,
            "unet_model_path": args.unet_model_path,
            "vae_model_path": args.vae_model_path,
            "lora_rank": args.lora_rank,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
            "learning_rate": args.learning_rate,
            "train_batch_size": args.train_batch_size,
            "num_train_epochs": args.num_train_epochs,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
        }
        with open(config_save_path, 'w') as f:
            json.dump(training_config, f, indent=2)
        logger.info(f"Training config saved to {config_save_path}")

        logger.info(f"Training completed. All models saved to {args.output_dir}")
        logger.info("Saved components:")
        logger.info(f" - Adapter: {adapter_save_path}/adapter.safetensors")
        logger.info(f" - LoRA weights only: {lora_save_path}/lora_weights.safetensors")
        logger.info(f" - UNet with fused LoRA: {unet_fused_save_path}")
        logger.info(f" - Training config: {config_save_path}")

    # Finish wandb run
    if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
        wandb.finish()
    accelerator.end_training()
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()