| |
| """ |
| 训练 QwenIllustrious 模型 - 结合 Qwen 文本编码器和 SDXL UNet |
| """ |
|
|
| import argparse |
| import logging |
| import math |
| import os |
| import random |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| import transformers |
| import numpy as np |
| import wandb |
| from accelerate import Accelerator |
| from accelerate.logging import get_logger |
| from accelerate.utils import ProjectConfiguration, set_seed |
| from tqdm.auto import tqdm |
| from transformers import AutoTokenizer |
| from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
| from diffusers.optimization import get_scheduler |
| from diffusers.training_utils import compute_snr |
| from diffusers.utils import check_min_version |
| from peft import LoraConfig, get_peft_model, TaskType |
|
|
| |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from arch import QwenTextEncoder, QwenEmbeddingAdapter |
| from arch.data_loader import QwenIllustriousDataset, collate_fn |
|
|
| from arch.model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler |
|
|
| |
| check_min_version("0.35.0.dev0") |
|
|
| logger = get_logger(__name__) |
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for QwenIllustrious training.

    Returns an argparse.Namespace covering model component paths, dataset/cache
    options, core training hyperparameters, LoRA settings, and logging/checkpoint
    configuration.
    """
    parser = argparse.ArgumentParser(description="Train QwenIllustrious model")
    add = parser.add_argument

    # --- model component paths ---
    add("--qwen_model_path", type=str, default="models/Qwen3-Embedding-0.6B",
        help="Path to Qwen text encoder model")
    add("--unet_model_path", type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
        help="Path to UNet model")
    add("--unet_config_path", type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
        help="Path to SDXL model config file")
    add("--vae_model_path", type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        help="Path to VAE model (if different from SDXL)")
    add("--vae_config_path", type=str,
        default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        help="Path to VAE config file")

    # --- dataset / caching ---
    add("--dataset_path", type=str, default="illustrious_generated",
        help="Path to illustrious_generated dataset")
    # store_false flag: passing --no_precompute_embeddings flips the default-on
    # precompute_embeddings setting off.
    add("--no_precompute_embeddings", action="store_false",
        dest="precompute_embeddings",
        help="Disable precomputing and caching Qwen embeddings and VAE latents")
    parser.set_defaults(precompute_embeddings=True)
    add("--cache_dir", type=str, default="./illustrious_generated/cache",
        help="Directory to store precomputed embeddings")

    # --- core training hyperparameters ---
    add("--output_dir", type=str, default="./qwen_illustrious_output",
        help="Output directory for trained model")
    add("--train_batch_size", type=int, default=1, help="Batch size for training")
    add("--num_train_epochs", type=int, default=10, help="Number of training epochs")
    add("--learning_rate", type=float, default=1e-4, help="Learning rate")
    add("--max_train_steps", type=int, default=None,
        help="Maximum number of training steps")
    add("--gradient_accumulation_steps", type=int, default=1,
        help="Number of gradient accumulation steps")
    add("--gradient_checkpointing", action="store_true",
        help="Enable gradient checkpointing")
    add("--mixed_precision", type=str, default="fp16",
        choices=["no", "fp16", "bf16"], help="Mixed precision training")

    # --- LoRA configuration ---
    add("--lora_rank", type=int, default=64,
        help="LoRA rank for SDXL UNet cross attention")
    add("--lora_alpha", type=int, default=64, help="LoRA alpha")
    add("--lora_dropout", type=float, default=0.1, help="LoRA dropout")

    # --- misc / logging / checkpointing ---
    add("--seed", type=int, default=42, help="Random seed")
    add("--logging_dir", type=str, default="logs", help="Logging directory")
    add("--report_to", type=str, default="wandb",
        help="Logging service (tensorboard, wandb, or all)")
    add("--wandb_project", type=str, default="qwen-illustrious",
        help="Wandb project name")
    add("--wandb_run_name", type=str, default=None,
        help="Wandb run name (optional)")
    add("--checkpointing_steps", type=int, default=25000,
        help="Save checkpoint every N steps")
    add("--resume_from_checkpoint", type=str, default=None,
        help="Path to checkpoint to resume from")
    add("--validation_epochs", type=int, default=1,
        help="Run validation every N epochs")
    add("--validation_prompts", type=str, nargs="+",
        default=[
            "A beautiful anime girl in a garden",
            "Two characters having a conversation",
            "A magical fantasy scene",
        ],
        help="Validation prompts")

    return parser.parse_args()
|
|
|
|
def setup_models(args, accelerator):
    """Load all model components and configure which parameters train.

    Loads the frozen Qwen text encoder, VAE, UNet and noise scheduler, wraps
    the UNet's cross-attention key/value projections with LoRA, and leaves only
    the adapter plus the LoRA weights trainable.

    Returns:
        (qwen_text_encoder, unet, vae, noise_scheduler, adapter)
    """
    logger.info("Loading models...")

    # Frozen Qwen text encoder (GPU when available).
    qwen_text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        max_length=512,
        freeze_encoder=True,
    )

    vae = load_vae_from_safetensors(args.vae_model_path, args.vae_config_path)
    unet = load_unet_from_safetensors(args.unet_model_path, args.unet_config_path)
    noise_scheduler = create_scheduler()

    # Adapter projecting Qwen embeddings into the SDXL conditioning space.
    adapter = QwenEmbeddingAdapter()

    logger.info(f"Setting up LoRA with rank={args.lora_rank}, alpha={args.lora_alpha}")

    # Target only the cross-attention (attn2) key/value projections.
    target_modules = [
        name
        for name, _ in unet.named_modules()
        if "attn2" in name and ("to_k" in name or "to_v" in name)
    ]
    if not target_modules:
        logger.warning("No cross attention to_k/to_v modules found. Using default modules.")
        target_modules = ["to_k", "to_v"]

    logger.info(f"Applying LoRA to modules: {target_modules}")

    unet = get_peft_model(
        unet,
        LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
        ),
    )

    # Freeze everything, then selectively re-enable adapter + LoRA gradients.
    vae.requires_grad_(False)
    qwen_text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    adapter.requires_grad_(True)
    for param_name, param in unet.named_parameters():
        if "lora" in param_name:
            param.requires_grad_(True)

    # Report parameter counts for sanity-checking the trainable footprint.
    total_params = sum(p.numel() for p in unet.parameters())
    trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    adapter_params = sum(p.numel() for p in adapter.parameters())
    logger.info(f"UNet total parameters: {total_params:,}")
    logger.info(f"UNet trainable parameters: {trainable_params:,}")
    logger.info(f"Adapter parameters: {adapter_params:,}")
    logger.info(f"Total trainable parameters: {trainable_params + adapter_params:,}")

    return qwen_text_encoder, unet, vae, noise_scheduler, adapter
|
|
|
|
def setup_dataset(args, qwen_text_encoder, vae, accelerator):
    """Build the training dataset and dataloader.

    When precompute is enabled the dataset receives the encoders and a cache
    directory and eagerly computes all embeddings/latents on the accelerator
    device; otherwise the dataset yields raw images and prompts.
    """
    logger.info("Setting up dataset...")

    precompute = args.precompute_embeddings
    dataset = QwenIllustriousDataset(
        dataset_path=args.dataset_path,
        qwen_text_encoder=qwen_text_encoder if precompute else None,
        vae=vae if precompute else None,
        cache_dir=args.cache_dir if precompute else None,
        precompute_embeddings=precompute,
    )

    if precompute:
        logger.info("Precomputing embeddings...")
        dataset.precompute_all(accelerator.device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    return dataset, dataloader
|
|
|
|
def training_step(batch, unet, adapter, noise_scheduler, vae, qwen_text_encoder, accelerator, args):
    """Compute the diffusion MSE loss for one batch.

    Uses cached latents/Qwen-embeddings when precompute is enabled; otherwise
    encodes images via the VAE and prompts via the Qwen encoder on the fly
    (both without gradients). Returns a scalar loss tensor.
    """
    device = accelerator.device

    if args.precompute_embeddings:
        # Cached path: latents and embeddings come straight from the dataset.
        latents = batch["latents"].to(device)
        text_embeddings = adapter.forward_text_embeddings(batch["text_embeddings"].to(device))
        pooled_embeddings = adapter.forward_pooled_embeddings(batch["pooled_embeddings"].to(device))
    else:
        # On-the-fly path: VAE-encode images and Qwen-encode prompts, no grads.
        images = batch["images"].to(device)
        prompts = batch["prompts"]
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            qwen_embeddings = qwen_text_encoder.encode_prompts(prompts, do_classifier_free_guidance=False)
        # Adapter runs with gradients — it is one of the trainable components.
        text_embeddings = adapter.forward_text_embeddings(qwen_embeddings[0])
        pooled_embeddings = adapter.forward_pooled_embeddings(qwen_embeddings[1])

    # Sample noise and a uniform random timestep per example, then forward-diffuse.
    bsz = latents.shape[0]
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,),
        device=latents.device, dtype=torch.long,
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # SDXL-style additional conditioning.
    # NOTE(review): time_ids are all zeros; SDXL normally packs
    # (orig_size, crop_coords, target_size) here — confirm this is intentional.
    added_cond_kwargs = {
        "text_embeds": pooled_embeddings,
        "time_ids": torch.zeros((bsz, 6), device=latents.device),
    }

    model_pred = unet(
        noisy_latents,
        timesteps,
        text_embeddings,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]

    # Regression target depends on the scheduler's prediction parameterization.
    pred_type = noise_scheduler.config.prediction_type
    if pred_type == "epsilon":
        target = noise
    elif pred_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
def validate_model(args, qwen_text_encoder, unet, adapter, vae, accelerator, epoch):
    """Validation hook — currently a placeholder that returns no metrics.

    Intended to eventually generate images from args.validation_prompts and
    compute quality metrics; for now it only logs and returns an empty dict.
    """
    logger.info(f"Running validation at epoch {epoch}")

    metrics = {}
    # TODO: run inference on args.validation_prompts and populate `metrics`.

    logger.info("Validation completed")
    return metrics
|
|
|
|
def main():
    """Entry point: configure logging/accelerate, run the training loop, save artifacts."""
    args = parse_args()

    # Start a wandb run up front so the full argument set is recorded.
    # NOTE(review): the Accelerator below is built with log_with=args.report_to and
    # init_trackers() is also called — that may open a second wandb run and the
    # explicit wandb.log calls in the loop may double-log metrics; confirm.
    if args.report_to in ["wandb", "all"]:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Accelerate project layout: state under output_dir, tracker logs under logging_dir.
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir,
        logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Print accelerator state from every rank (not only the main process).
    logger.info(accelerator.state, main_process_only=False)

    # Seed all RNGs for reproducibility.
    if args.seed is not None:
        set_seed(args.seed)

    # Create output/cache dirs on the main process only (avoids mkdir races).
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        if args.precompute_embeddings:
            os.makedirs(args.cache_dir, exist_ok=True)

    # Models: frozen Qwen encoder + VAE, LoRA-wrapped UNet, trainable adapter.
    qwen_text_encoder, unet, vae, noise_scheduler, adapter = setup_models(args, accelerator)

    dataset, dataloader = setup_dataset(args, qwen_text_encoder, vae, accelerator)

    # Optimize the adapter plus only the LoRA parameters left trainable in the UNet.
    trainable_params = list(adapter.parameters())
    for param in unet.parameters():
        if param.requires_grad:
            trainable_params.append(param)

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8,
    )

    # Derive max_train_steps from epochs when it was not given explicitly.
    num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Cosine decay with a fixed 500-step warmup.
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.max_train_steps,
    )

    unet, adapter, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        unet, adapter, optimizer, dataloader, lr_scheduler
    )

    # Keep the frozen encoder/VAE off the GPU to save memory.
    # NOTE(review): with --no_precompute_embeddings, training_step still calls this
    # CPU-resident vae/text-encoder on device tensors — likely a device mismatch; confirm.
    qwen_text_encoder.to('cpu')
    vae.to('cpu')

    if accelerator.is_main_process:
        tracker_config = vars(args)
        accelerator.init_trackers(args.wandb_project, config=tracker_config)

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size = {args.train_batch_size * accelerator.num_processes}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Optionally resume full training state from a checkpoint directory.
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # NOTE(review): existence of this path is never checked before load_state.
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # "latest": pick the checkpoint-N directory with the highest step number.
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Directory name encodes the step: "checkpoint-<global_step>".
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        adapter.train()
        # NOTE(review): train_loss is accumulated then reset every sync step but
        # never logged — appears vestigial.
        train_loss = 0.0
        epoch_loss = 0.0
        num_batches = 0

        for step, batch in enumerate(dataloader):
            # NOTE(review): only `unet` is passed to accumulate(); the adapter is
            # not covered by the same accumulation context — confirm intended.
            with accelerator.accumulate(unet):
                loss = training_step(batch, unet, adapter, noise_scheduler, vae, qwen_text_encoder, accelerator, args)

                accelerator.backward(loss)

                # Clip only on real optimizer steps (when grads are synchronized).
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(trainable_params, 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Per-optimizer-step bookkeeping (once per effective batch).
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # Average the loss across processes for logging.
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps
                epoch_loss += avg_loss.item()
                num_batches += 1

                log_dict = {
                    "train/step_loss": avg_loss.item(),
                    "train/learning_rate": lr_scheduler.get_last_lr()[0],
                    "train/epoch": epoch,
                    "train/global_step": global_step
                }
                accelerator.log(log_dict, step=global_step)

                # NOTE(review): accelerator.log already forwards to wandb when
                # log_with includes wandb — this explicit call likely double-logs; confirm.
                if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
                    wandb.log(log_dict, step=global_step)

                train_loss = 0.0

                # Periodic full training-state checkpoint (main process only).
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            # NOTE(review): this only breaks the inner loop; the epoch loop keeps
            # going, so a few extra steps past max_train_steps can run — confirm.
            if global_step >= args.max_train_steps:
                break

        # Epoch-level average loss logging (main process only).
        if num_batches > 0 and accelerator.is_main_process:
            avg_epoch_loss = epoch_loss / num_batches
            epoch_log_dict = {
                "train/epoch_loss": avg_epoch_loss,
                "train/epoch_num": epoch
            }
            accelerator.log(epoch_log_dict, step=global_step)
            if args.report_to in ["wandb", "all"]:
                wandb.log(epoch_log_dict, step=global_step)
            logger.info(f"Epoch {epoch} - Average Loss: {avg_epoch_loss:.4f}")

        # Periodic validation (currently a placeholder returning no metrics).
        if epoch % args.validation_epochs == 0:
            validation_metrics = validate_model(
                args, qwen_text_encoder, unet, adapter, vae, accelerator, epoch
            )
            if validation_metrics:
                accelerator.log(validation_metrics, step=global_step)
                if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
                    wandb.log(validation_metrics, step=global_step)

    # Final export: adapter weights, stand-alone LoRA, fused UNet, config (rank 0 only).
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        from safetensors.torch import save_file
        from peft import get_peft_model_state_dict

        logger.info("Saving trained models...")

        # 1) Qwen->SDXL embedding adapter weights.
        adapter_save_path = os.path.join(args.output_dir, "adapter")
        os.makedirs(adapter_save_path, exist_ok=True)
        adapter_state_dict = adapter.state_dict()
        save_file(adapter_state_dict, os.path.join(adapter_save_path, "adapter.safetensors"))
        logger.info(f"Adapter saved to {adapter_save_path}/adapter.safetensors")

        # 2) Stand-alone LoRA weights + config.
        lora_save_path = os.path.join(args.output_dir, "lora_weights")
        os.makedirs(lora_save_path, exist_ok=True)

        lora_state_dict = get_peft_model_state_dict(unet)
        save_file(lora_state_dict, os.path.join(lora_save_path, "lora_weights.safetensors"))
        logger.info(f"LoRA weights saved to {lora_save_path}/lora_weights.safetensors")

        # NOTE(review): lora_config_path is assigned but unused — save_pretrained
        # writes the config itself.
        lora_config_path = os.path.join(lora_save_path, "adapter_config.json")
        unet.peft_config['default'].save_pretrained(lora_save_path)
        logger.info(f"LoRA config saved to {lora_save_path}/adapter_config.json")

        # 3) UNet with LoRA merged into the base weights.
        logger.info("Fusing LoRA weights into UNet...")
        unet_fused_save_path = os.path.join(args.output_dir, "unet_fused")
        os.makedirs(unet_fused_save_path, exist_ok=True)

        # NOTE(review): this local import is unused.
        from diffusers import UNet2DConditionModel
        unet_base = unet

        # NOTE(review): `unet` is already a PeftModel (and accelerator-prepared);
        # PeftModel.from_pretrained here wraps it a second time, so the merge may
        # double-apply LoRA. Merging via unet.merge_and_unload() on the unwrapped
        # base model is the usual pattern — confirm.
        from peft import PeftModel
        unet_merged = PeftModel.from_pretrained(unet_base, lora_save_path)
        unet_merged = unet_merged.merge_and_unload()

        unet_merged.save_pretrained(
            unet_fused_save_path,
            safe_serialization=True
        )
        logger.info(f"Fused UNet saved to {unet_fused_save_path}")

        # 4) Training hyperparameters for reproducibility.
        import json
        config_save_path = os.path.join(args.output_dir, "training_config.json")
        training_config = {
            "qwen_model_path": args.qwen_model_path,
            "unet_model_path": args.unet_model_path,
            "vae_model_path": args.vae_model_path,
            "lora_rank": args.lora_rank,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
            "learning_rate": args.learning_rate,
            "train_batch_size": args.train_batch_size,
            "num_train_epochs": args.num_train_epochs,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
        }
        with open(config_save_path, 'w') as f:
            json.dump(training_config, f, indent=2)
        logger.info(f"Training config saved to {config_save_path}")

        logger.info(f"Training completed. All models saved to {args.output_dir}")
        logger.info("Saved components:")
        logger.info(f" - Adapter: {adapter_save_path}/adapter.safetensors")
        logger.info(f" - LoRA weights only: {lora_save_path}/lora_weights.safetensors")
        logger.info(f" - UNet with fused LoRA: {unet_fused_save_path}")
        logger.info(f" - Training config: {config_save_path}")

    # Shut down trackers cleanly.
    if args.report_to in ["wandb", "all"] and accelerator.is_main_process:
        wandb.finish()

    accelerator.end_training()
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|