| |
| |
| """ |
| SD3 LoRA分布式采样脚本 |
| 使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估 |
| """ |
|
|
| import torch |
| import torch.distributed as dist |
| from tqdm import tqdm |
| import os |
| from PIL import Image |
| import numpy as np |
| import math |
| import argparse |
| import sys |
| import json |
| import random |
| from pathlib import Path |
|
|
| from diffusers import ( |
| StableDiffusion3Pipeline, |
| AutoencoderKL, |
| FlowMatchEulerDiscreteScheduler, |
| SD3Transformer2DModel, |
| ) |
| from transformers import CLIPTokenizer, T5TokenizerFast |
| from accelerate import Accelerator |
| from peft import LoraConfig |
| from peft.utils import get_peft_model_state_dict |
|
|
|
|
def create_npz_from_sample_folder(sample_dir, num_samples):
    """
    Build a single .npz file from a folder of .png samples, keeping the same
    format as sample_ddp_new (uint8 image stack stored under key ``arr_0``).

    Args:
        sample_dir: Directory containing the .png sample images.
        num_samples: Maximum number of samples to include in the archive.

    Returns:
        Path of the written ``<sample_dir>.npz`` file, or None when the
        folder contains no .png samples.
    """
    # Collect .png filenames in deterministic (sorted) order.
    actual_files = [name for name in sorted(os.listdir(sample_dir)) if name.endswith('.png')]

    # NOTE: the previous version carried a zero-padding else-branch here that
    # was unreachable — the loop bound is min(num_samples, len(actual_files)),
    # so the index can never exceed the file list. Dead branch removed.
    samples = []
    for i in tqdm(range(min(num_samples, len(actual_files))), desc="Building .npz file from samples"):
        sample_path = os.path.join(sample_dir, actual_files[i])
        sample_pil = Image.open(sample_path)
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)

    if not samples:
        print("No samples found to create npz file.")
        return None

    samples = np.stack(samples)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
|
|
|
def find_latest_checkpoint(output_dir):
    """
    Locate the newest ``checkpoint-<step>`` subdirectory of *output_dir*.

    Returns:
        Tuple ``(path, step)`` for the checkpoint with the highest step
        number, or ``(None, None)`` when no valid checkpoint directory exists.
    """
    candidates = []
    if os.path.exists(output_dir):
        for entry in os.listdir(output_dir):
            full_path = os.path.join(output_dir, entry)
            if not (entry.startswith("checkpoint-") and os.path.isdir(full_path)):
                continue
            try:
                # Directory names look like "checkpoint-<step>".
                candidates.append((int(entry.split("-")[1]), entry))
            except (ValueError, IndexError):
                continue

    if not candidates:
        return None, None

    # Stable sort by step, then pick the last (largest) entry.
    newest_step, newest_name = sorted(candidates, key=lambda c: c[0])[-1]
    return os.path.join(output_dir, newest_name), newest_step
|
|
|
|
def check_lora_weights_exist(lora_path):
    """
    Return True when *lora_path* points at LoRA weights.

    A directory qualifies if it holds the standard diffusers file
    ``pytorch_lora_weights.safetensors`` or any ``*lora*.safetensors`` file;
    a plain file qualifies if it ends with ``.safetensors``.
    """
    if not lora_path:
        return False

    if os.path.isdir(lora_path):
        # Check the canonical diffusers filename first.
        canonical = os.path.join(lora_path, "pytorch_lora_weights.safetensors")
        if os.path.exists(canonical):
            return True
        # Otherwise accept any safetensors file whose name mentions "lora".
        return any(
            name.endswith(".safetensors") and "lora" in name.lower()
            for name in os.listdir(lora_path)
        )

    if os.path.isfile(lora_path):
        return lora_path.endswith(".safetensors")

    return False
|
|
|
|
def check_full_finetune_checkpoint(checkpoint_path):
    """
    Return True when *checkpoint_path* is a full fine-tuning checkpoint
    directory, i.e. one that contains ``model.safetensors``.
    """
    if checkpoint_path and os.path.isdir(checkpoint_path):
        return os.path.exists(os.path.join(checkpoint_path, "model.safetensors"))
    return False
|
|
|
|
def load_lora_from_checkpoint(pipeline, checkpoint_path, rank=0):
    """
    Attach a LoRA adapter to the pipeline's transformer and restore its
    weights from an accelerate checkpoint directory.

    Args:
        pipeline: SD3 pipeline whose transformer receives the adapter.
        checkpoint_path: Directory produced by ``accelerator.save_state``.
        rank: Distributed rank; only rank 0 prints progress messages.

    Returns:
        True on success, False if loading failed (pipeline stays baseline).
    """
    if rank == 0:
        print(f"Loading LoRA weights from checkpoint: {checkpoint_path}")

    try:
        # An Accelerator instance is needed purely to drive load_state().
        state_loader = Accelerator()

        # Adapter hyper-parameters must mirror those used during training.
        adapter_config = LoraConfig(
            r=64,
            lora_alpha=64,
            init_lora_weights="gaussian",
            target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
        )
        pipeline.transformer.add_adapter(adapter_config)

        # Restore the adapter weights from the accelerate checkpoint.
        state_loader.load_state(checkpoint_path)

        if rank == 0:
            print(f"Successfully loaded LoRA weights from checkpoint {checkpoint_path}")
        return True
    except Exception as e:
        if rank == 0:
            print(f"Error loading LoRA from checkpoint {checkpoint_path}: {e}")
            print("Falling back to baseline model without LoRA")
        return False
|
|
|
|
def load_captions_from_jsonl(jsonl_path):
    """
    Read caption strings from a JSONL file (one JSON object per line).

    For each record, the fields ``caption``/``text``/``prompt``/``description``
    are tried in that order; if none holds a non-empty string, the first
    non-empty string value in the record is used instead. Lines that are blank
    or not valid JSON are skipped (invalid JSON emits a warning).

    Returns:
        List of captions; an empty list when the file is missing or unreadable.
    """
    preferred_fields = ('caption', 'text', 'prompt', 'description')
    captions = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_num, raw in enumerate(f, 1):
                raw = raw.strip()
                if not raw:
                    continue

                try:
                    record = json.loads(raw)
                except json.JSONDecodeError as e:
                    print(f"Warning: Invalid JSON on line {line_num}: {e}")
                    continue

                # First matching preferred field wins (value stripped).
                caption = next(
                    (record[key].strip() for key in preferred_fields
                     if isinstance(record.get(key), str)),
                    None,
                )
                if caption:
                    captions.append(caption)
                else:
                    # Fallback: first non-empty string value in the record.
                    for value in record.values():
                        if isinstance(value, str) and value.strip():
                            captions.append(value.strip())
                            break

    except FileNotFoundError:
        print(f"Error: JSONL file {jsonl_path} not found")
        return []
    except Exception as e:
        print(f"Error loading JSONL file {jsonl_path}: {e}")
        return []

    print(f"Loaded {len(captions)} captions from {jsonl_path}")
    return captions
|
|
|
|
def main(args):
    """
    Run SD3 LoRA distributed sampling.

    Each rank generates its share of images from the JSONL captions, saves
    them as .png files into a shared sample folder, and rank 0 finally packs
    them into a single .npz archive for evaluation.
    """
    assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
    torch.set_grad_enabled(False)

    # --- Distributed setup: one GPU per rank, per-rank deterministic seed. ---
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # --- Load captions; fall back to one generic prompt when unavailable. ---
    captions = []
    if args.captions_jsonl:
        if rank == 0:
            print(f"Loading captions from {args.captions_jsonl}")
        captions = load_captions_from_jsonl(args.captions_jsonl)
        if not captions:
            if rank == 0:
                print("Warning: No captions loaded, using default caption")
            captions = ["a beautiful high quality image"]
    else:
        captions = ["a beautiful high quality image"]

    # Total images requested, capped by --max_samples.
    total_images_needed = len(captions) * args.images_per_caption
    total_images_needed = min(total_images_needed, args.max_samples)
    if rank == 0:
        print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
        print(f"Total images requested: {len(captions) * args.images_per_caption}")
        print(f"Max samples limit: {args.max_samples}")
        print(f"Total images to generate: {total_images_needed}")

    # --- Pick the inference dtype from the mixed-precision flag. ---
    if args.mixed_precision == "fp16":
        dtype = torch.float16
    elif args.mixed_precision == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    # --- Model loading. Preference order:
    #   1) full fine-tune checkpoint at --lora_path (has model.safetensors)
    #   2) baseline model + LoRA from --lora_path
    #   3) baseline model + LoRA found in cwd or known output dirs
    #   4) plain baseline model
    is_full_finetune = False
    if args.lora_path and check_full_finetune_checkpoint(args.lora_path):
        if rank == 0:
            print(f"Detected full fine-tuning checkpoint, loading from: {args.lora_path}")
        try:
            pipeline = StableDiffusion3Pipeline.from_pretrained(
                args.lora_path,
                revision=args.revision,
                variant=args.variant,
                torch_dtype=dtype,
            )
            is_full_finetune = True
            # Use the checkpoint directory name to label the output folder.
            lora_source = os.path.basename(args.lora_path.rstrip('/'))
            if rank == 0:
                print("Successfully loaded full fine-tuned model from checkpoint")
        except Exception as e:
            if rank == 0:
                print(f"Failed to load full fine-tuned model: {e}")
                print("Falling back to baseline model + LoRA loading")
            is_full_finetune = False

    # Load the baseline pipeline unless the full fine-tune already succeeded.
    if not is_full_finetune:
        if rank == 0:
            print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=dtype,
        )

    lora_loaded = False
    # Keep the checkpoint-derived label when a full fine-tune was loaded.
    lora_source = "baseline" if not is_full_finetune else lora_source

    # Attempt 1: LoRA weights at the explicitly supplied --lora_path.
    if not is_full_finetune and args.lora_path:
        if check_lora_weights_exist(args.lora_path):
            if rank == 0:
                print(f"Loading LoRA weights from specified path: {args.lora_path}")
            try:
                pipeline.load_lora_weights(args.lora_path)
                lora_loaded = True
                lora_source = os.path.basename(args.lora_path.rstrip('/'))
                if rank == 0:
                    print("Successfully loaded LoRA weights from specified path")
            except Exception as e:
                if rank == 0:
                    print(f"Failed to load LoRA from specified path: {e}")
        else:
            if rank == 0:
                print(f"No LoRA weights found at specified path: {args.lora_path}")

    # Attempt 2: look for LoRA weights in the current working directory.
    if not is_full_finetune and not lora_loaded:
        current_dir = os.getcwd()
        if check_lora_weights_exist(current_dir):
            if rank == 0:
                print(f"Found LoRA weights in current directory: {current_dir}")
            try:
                pipeline.load_lora_weights(current_dir)
                lora_loaded = True
                lora_source = "current_dir"
                if rank == 0:
                    print("Successfully loaded LoRA weights from current directory")
            except Exception as e:
                if rank == 0:
                    print(f"Failed to load LoRA from current directory: {e}")

        # Attempt 3: scan conventional training output dirs for weights or
        # the latest accelerate checkpoint.
        if not lora_loaded:
            possible_output_dirs = [
                "sd3-lora-finetuned",
                "sd3-lora-finetuned-last",
                "output",
                "checkpoints"
            ]

            checkpoint_found = False
            for output_dir in possible_output_dirs:
                if os.path.exists(output_dir):
                    # Prefer exported LoRA weight files over raw checkpoints.
                    if check_lora_weights_exist(output_dir):
                        if rank == 0:
                            print(f"Found LoRA weights in output directory: {output_dir}")
                        try:
                            pipeline.load_lora_weights(output_dir)
                            lora_loaded = True
                            lora_source = output_dir
                            if rank == 0:
                                print(f"Successfully loaded LoRA weights from {output_dir}")
                            break
                        except Exception as e:
                            if rank == 0:
                                print(f"Failed to load LoRA from {output_dir}: {e}")

                    # Otherwise, try the newest checkpoint-<step> directory.
                    if not lora_loaded:
                        latest_checkpoint, latest_step = find_latest_checkpoint(output_dir)
                        if latest_checkpoint:
                            if rank == 0:
                                print(f"Found latest checkpoint: {latest_checkpoint} (step {latest_step})")
                            if load_lora_from_checkpoint(pipeline, latest_checkpoint, rank):
                                lora_loaded = True
                                lora_source = f"checkpoint-{latest_step}"
                                checkpoint_found = True
                                break

            if not checkpoint_found and not lora_loaded:
                if rank == 0:
                    print("No LoRA weights or checkpoints found. Using baseline model.")

    # --- Device placement and memory optimizations. ---
    if args.enable_cpu_offload:
        if rank == 0:
            print("Enabling CPU offload to save memory")
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to(device)
        if rank == 0:
            print("Enabling memory optimization options")

        # Attention slicing is optional; tolerate pipelines without it.
        if hasattr(pipeline, 'enable_attention_slicing'):
            try:
                pipeline.enable_attention_slicing()
                if rank == 0:
                    print(" - Attention slicing enabled")
            except Exception as e:
                if rank == 0:
                    print(f" - Warning: Failed to enable attention slicing: {e}")
        else:
            if rank == 0:
                print(" - Attention slicing not available for this pipeline")

        # VAE slicing may not exist on SD3 pipelines; probe defensively.
        enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
        if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
            try:
                enable_vae_slicing_method()
                if rank == 0:
                    print(" - VAE slicing enabled")
            except Exception as e:
                if rank == 0:
                    print(f" - Warning: Failed to enable VAE slicing: {e}")
        else:
            if rank == 0:
                print(" - VAE slicing not available for this pipeline (SD3 may not support this)")

    # Suppress diffusers' per-call progress bars (tqdm below covers progress).
    pipeline.set_progress_bar_config(disable=True)

    # Output folder name encodes the sampling configuration.
    folder_name = f"batch32-rank64-last-sd3-{lora_source}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
    sample_folder_dir = os.path.join(args.sample_dir, folder_name)

    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")

        # Start captions.txt fresh for this run.
        caption_file = os.path.join(sample_folder_dir, "captions.txt")
        if os.path.exists(caption_file):
            os.remove(caption_file)
    dist.barrier()

    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()

    # Count already-generated images so an interrupted run can resume.
    existing_samples = 0
    if os.path.exists(sample_folder_dir):
        existing_samples = len([
            name for name in os.listdir(sample_folder_dir)
            if os.path.isfile(os.path.join(sample_folder_dir, name)) and name.endswith(".png")
        ])

    # Round the target up to a multiple of the global batch size.
    total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"Existing samples: {existing_samples}")

    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"

    iterations = int(samples_needed_this_gpu // n)
    # Assumes existing samples were produced evenly across ranks — TODO confirm.
    done_iterations = int(int(existing_samples // dist.get_world_size()) // n)

    pbar = range(done_iterations, iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar

    # Flatten (caption, caption_idx, image_idx) so work can be indexed globally.
    caption_image_pairs = []
    for i, caption in enumerate(captions):
        for j in range(args.images_per_caption):
            caption_image_pairs.append((caption, i, j))

    total_generated = existing_samples

    # --- Main sampling loop: each rank handles an interleaved slice. ---
    for i in pbar:
        batch_prompts = []
        batch_caption_info = []

        for j in range(n):
            # Interleaved assignment: rank r takes indices r, r+W, r+2W, ...
            global_index = i * global_batch_size + j * dist.get_world_size() + rank
            if global_index < len(caption_image_pairs):
                caption, caption_idx, image_idx = caption_image_pairs[global_index]
                batch_prompts.append(caption)
                batch_caption_info.append((caption, caption_idx, image_idx))
            else:
                # Past the end (rounding overshoot): repeat the last pair so
                # batch shapes stay uniform; these extras are not saved below.
                if caption_image_pairs:
                    caption, caption_idx, image_idx = caption_image_pairs[-1]
                    batch_prompts.append(caption)
                    batch_caption_info.append((caption, caption_idx, image_idx))
                else:
                    batch_prompts.append("a beautiful high quality image")
                    batch_caption_info.append(("a beautiful high quality image", 0, 0))

        device_str = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.autocast(device_str, dtype=dtype):
            # Generate one image per prompt with a unique, reproducible seed.
            images = []
            for k, prompt in enumerate(batch_prompts):
                image_seed = seed + i * 10000 + k * 1000 + rank
                generator = torch.Generator(device=device).manual_seed(image_seed)

                image = pipeline(
                    prompt=prompt,
                    negative_prompt=args.negative_prompt if args.negative_prompt else None,
                    height=args.height,
                    width=args.width,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    num_images_per_prompt=1,
                ).images[0]
                images.append(image)

                # Free cached GPU memory after the last image of the batch.
                if k == len(batch_prompts) - 1:
                    torch.cuda.empty_cache()

        # Save only images whose global index is within the real workload.
        for j, (image, (caption, caption_idx, image_idx)) in enumerate(zip(images, batch_caption_info)):
            global_index = i * global_batch_size + j * dist.get_world_size() + rank
            if global_index < len(caption_image_pairs):
                filename = f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"
                image_path = os.path.join(sample_folder_dir, filename)
                image.save(image_path)

                # NOTE(review): this writes the literal string "(unknown)"
                # instead of the image filename — likely a bug; the intended
                # line was probably f"{filename}\t{caption}\n". Also, only
                # rank 0's captions are recorded. Confirm before relying on
                # captions.txt.
                if rank == 0:
                    caption_file = os.path.join(sample_folder_dir, "captions.txt")
                    with open(caption_file, "a", encoding="utf-8") as f:
                        f.write(f"(unknown)\t{caption}\n")

        total_generated += global_batch_size

        torch.cuda.empty_cache()

        # Keep ranks in lock-step each iteration.
        dist.barrier()

    # Wait for every rank to finish sampling before packaging.
    dist.barrier()

    # --- Rank 0 packs all generated .png files into one .npz archive. ---
    if rank == 0:
        actual_num_samples = len([name for name in os.listdir(sample_folder_dir) if name.endswith(".png")])
        print(f"Actually generated {actual_num_samples} images")

        npz_samples = min(actual_num_samples, total_images_needed, args.max_samples)
        create_npz_from_sample_folder(sample_folder_dir, npz_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本") |
| |
| |
| parser.add_argument( |
| "--pretrained_model_name_or_path", |
| type=str, |
| default="stabilityai/stable-diffusion-3-medium-diffusers", |
| help="预训练模型路径或HuggingFace模型ID" |
| ) |
| parser.add_argument( |
| "--lora_path", |
| type=str, |
| default=None, |
| help="LoRA权重文件路径" |
| ) |
| parser.add_argument( |
| "--revision", |
| type=str, |
| default=None, |
| help="模型修订版本" |
| ) |
| parser.add_argument( |
| "--variant", |
| type=str, |
| default=None, |
| help="模型变体,如fp16" |
| ) |
| |
| |
| parser.add_argument( |
| "--num_inference_steps", |
| type=int, |
| default=28, |
| help="推理步数" |
| ) |
| parser.add_argument( |
| "--guidance_scale", |
| type=float, |
| default=7.0, |
| help="引导尺度" |
| ) |
| parser.add_argument( |
| "--height", |
| type=int, |
| default=1024, |
| help="生成图像高度" |
| ) |
| parser.add_argument( |
| "--width", |
| type=int, |
| default=1024, |
| help="生成图像宽度" |
| ) |
| parser.add_argument( |
| "--negative_prompt", |
| type=str, |
| default="", |
| help="负面提示词" |
| ) |
| |
| |
| parser.add_argument( |
| "--per_proc_batch_size", |
| type=int, |
| default=1, |
| help="每个进程的批处理大小" |
| ) |
| parser.add_argument( |
| "--sample_dir", |
| type=str, |
| default="sd3_lora_samples", |
| help="样本保存目录" |
| ) |
| |
| |
| parser.add_argument( |
| "--captions_jsonl", |
| type=str, |
| required=True, |
| help="包含caption列表的JSONL文件路径" |
| ) |
| parser.add_argument( |
| "--images_per_caption", |
| type=int, |
| default=1, |
| help="每个caption生成的图像数量" |
| ) |
| parser.add_argument( |
| "--max_samples", |
| type=int, |
| default=30000, |
| help="最大样本生成数量" |
| ) |
| |
| |
| parser.add_argument( |
| "--global_seed", |
| type=int, |
| default=42, |
| help="全局随机种子" |
| ) |
| parser.add_argument( |
| "--mixed_precision", |
| type=str, |
| default="fp16", |
| choices=["no", "fp16", "bf16"], |
| help="混合精度类型" |
| ) |
| parser.add_argument( |
| "--enable_cpu_offload", |
| action="store_true", |
| help="启用CPU offload以节省显存" |
| ) |
| |
| args = parser.parse_args() |
| main(args) |