| |
| |
| """ |
| SD3 LoRA分布式采样脚本 - 从accelerator checkpoint加载LoRA权重 |
| 使用微调后的LoRA权重,基于JSONL文件中的caption生成图像样本,并保存为npz格式用于评估 |
| """ |
|
|
| import torch |
| import torch.distributed as dist |
| from tqdm import tqdm |
| import os |
| from PIL import Image |
| import numpy as np |
| import math |
| import argparse |
| import sys |
| import json |
| import random |
| from pathlib import Path |
|
|
| from diffusers import ( |
| StableDiffusion3Pipeline, |
| AutoencoderKL, |
| FlowMatchEulerDiscreteScheduler, |
| SD3Transformer2DModel, |
| ) |
| from transformers import CLIPTokenizer, T5TokenizerFast |
| from accelerate import Accelerator |
| from peft import LoraConfig, PeftModel |
| from peft.utils import get_peft_model_state_dict |
| from safetensors.torch import load_file, save_file |
|
|
|
|
def create_npz_from_sample_folder(sample_dir, num_samples):
    """Pack the PNG samples of a folder into a single ``.npz`` archive.

    The ``.png`` files directly inside ``sample_dir`` are taken in sorted
    filename order, at most ``num_samples`` of them are decoded to ``uint8``
    arrays, stacked along a new leading axis and stored under the key
    ``arr_0`` in ``f"{sample_dir}.npz"``.

    Args:
        sample_dir: Directory that contains the ``.png`` samples.
        num_samples: Upper bound on the number of images to pack.

    Returns:
        The path of the written ``.npz`` file, or ``None`` when there is
        nothing to pack (no PNG files, or ``num_samples`` <= 0).

    Raises:
        ValueError: Propagated from ``np.stack`` when the decoded images do
            not all share one shape.
    """
    png_files = sorted(
        name for name in os.listdir(sample_dir) if name.endswith('.png')
    )
    selected = png_files[:max(num_samples, 0)]

    # Bail out before building a progress bar when there is nothing to do.
    if not selected:
        print("No samples found to create npz file.")
        return None

    samples = []
    for filename in tqdm(selected, desc="Building .npz file from samples"):
        # The context manager releases the file handle right away; with tens
        # of thousands of samples, lazily closed handles can exhaust the
        # process's descriptor limit.
        with Image.open(os.path.join(sample_dir, filename)) as sample_pil:
            # astype() copies, so the array stays valid after the file closes.
            samples.append(np.asarray(sample_pil).astype(np.uint8))

    stacked = np.stack(samples)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
|
|
|
|
def extract_lora_from_checkpoint(checkpoint_path, output_lora_path, rank=64, rank0_only=True):
    """Extract the LoRA tensors of an accelerator checkpoint into their own file.

    Reads ``model.safetensors`` from ``checkpoint_path``, keeps every tensor
    whose key contains ``lora_A``, ``lora_B`` or ``lora_embedding``, drops a
    leading ``model.`` from those keys and, when ``rank0_only`` is true,
    writes the result to
    ``<output_lora_path>/pytorch_lora_weights.safetensors``.

    Args:
        checkpoint_path: Checkpoint directory holding ``model.safetensors``.
        output_lora_path: Directory that receives the extracted weights.
        rank: LoRA rank; accepted for interface compatibility, not used here.
        rank0_only: Gates both the console output and the writing of the
            output file.

    Returns:
        ``True`` when LoRA tensors were found (and saved if ``rank0_only``);
        ``False`` when the checkpoint file is missing, holds no LoRA tensors,
        or any exception is raised while processing it.
    """
    def report(message):
        # All console output of this function is gated on ``rank0_only``.
        if rank0_only:
            print(message)

    weights_path = os.path.join(checkpoint_path, "model.safetensors")
    if not os.path.exists(weights_path):
        report(f"Model file not found: {weights_path}")
        return False

    try:
        checkpoint_tensors = load_file(weights_path)
        report(f"Loaded checkpoint with {len(checkpoint_tensors)} keys")

        lora_markers = ('lora_A', 'lora_B', 'lora_embedding')
        matched_keys = [
            name for name in checkpoint_tensors
            if any(marker in name for marker in lora_markers)
        ]
        report(f"Found {len(matched_keys)} LoRA keys")

        if matched_keys:
            report(f"Sample LoRA keys: {matched_keys[:5]}")
        else:
            report("Warning: No LoRA keys found in checkpoint. Trying alternative extraction method...")
            # Plain transformer weights without any LoRA tensors indicate a
            # full fine-tuning checkpoint, which cannot be extracted here.
            base_weight_keys = [
                name for name in checkpoint_tensors
                if 'transformer' in name.lower() and 'lora' not in name.lower()
            ]
            if base_weight_keys:
                report(f"Found {len(base_weight_keys)} transformer keys (full fine-tuning checkpoint)")
                report("This checkpoint appears to contain full model weights, not LoRA weights.")
                report("You may need to use a different loading method.")
                return False

        # Strip the wrapper prefix so the keys are relative to the model.
        model_prefix = "model."
        extracted = {}
        for name in matched_keys:
            target = name[len(model_prefix):] if name.startswith(model_prefix) else name
            extracted[target] = checkpoint_tensors[name]

        if not extracted:
            report("Error: Failed to extract LoRA weights from checkpoint")
            return False

        if rank0_only:
            os.makedirs(output_lora_path, exist_ok=True)
            lora_file = os.path.join(output_lora_path, "pytorch_lora_weights.safetensors")
            save_file(extracted, lora_file)
            print(f"Saved LoRA weights to {lora_file} ({len(extracted)} keys)")

        return True

    except Exception as e:
        if rank0_only:
            print(f"Error extracting LoRA from checkpoint: {e}")
            import traceback
            traceback.print_exc()
        return False
|
|
|
|
def _build_transformer_lora_config(rank):
    """Return the LoRA configuration used for the SD3 transformer adapter."""
    return LoraConfig(
        r=rank,
        lora_alpha=rank,
        init_lora_weights="gaussian",
        target_modules=["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"],
    )


def _strip_checkpoint_prefix(key):
    """Drop the first matching wrapper prefix from a checkpoint key."""
    for prefix in ("model.", "module.", "transformer."):
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


def _print_key_report(label, keys, limit=10):
    """Print the number of ``keys`` and list at most ``limit`` of them."""
    print(f"{label} keys: {len(keys)}")
    if len(keys) > limit:
        print(f" (showing first {limit} of {len(keys)} {label.lower()} keys)")
    for key in list(keys)[:limit]:
        print(f" - {key}")


def _load_checkpoint_with_accelerator(pipeline, checkpoint_path, rank, rank0_print):
    """Attach a LoRA adapter and restore the checkpoint via ``Accelerator.load_state``."""
    try:
        pipeline.transformer.add_adapter(_build_transformer_lora_config(rank))
        accelerator = Accelerator()
        transformer_prepared = accelerator.prepare(pipeline.transformer)
        accelerator.load_state(checkpoint_path)
        pipeline.transformer = accelerator.unwrap_model(transformer_prepared)
    except Exception as e:
        if rank0_print:
            print(f"Failed to load using Accelerator: {e}")
        return False

    if rank0_print:
        print("Successfully loaded checkpoint using Accelerator")
    return True


def load_lora_from_checkpoint_direct(pipeline, checkpoint_path, rank=64, rank0_print=True):
    """Load LoRA weights from an accelerator checkpoint straight into ``pipeline``.

    ``model.safetensors`` is read from ``checkpoint_path``. Tensors whose key
    contains ``lora_A``, ``lora_B`` or ``lora_embedding`` are copied into a
    freshly added LoRA adapter on ``pipeline.transformer``. When the file
    holds no such tensors but does hold transformer weights, the checkpoint
    is restored through ``Accelerator.load_state`` instead.

    Args:
        pipeline: Pipeline whose ``transformer`` receives the adapter.
        checkpoint_path: Checkpoint directory holding ``model.safetensors``.
        rank: LoRA rank of the adapter (also used as ``lora_alpha``).
        rank0_print: Whether this process prints progress messages.

    Returns:
        ``True`` when weights were loaded. ``False`` when the checkpoint file
        is missing, its format is not recognised, none of its LoRA keys match
        the transformer's parameters, or an exception is raised.
    """
    model_file = os.path.join(checkpoint_path, "model.safetensors")
    if not os.path.exists(model_file):
        if rank0_print:
            print(f"Model file not found: {model_file}")
        return False

    try:
        state_dict = load_file(model_file)

        if rank0_print:
            print(f"Loaded checkpoint with {len(state_dict)} keys")
            sample_keys = list(state_dict.keys())[:10]
            print(f"Sample keys: {sample_keys}")

        lora_keys = [
            k for k in state_dict
            if 'lora_A' in k or 'lora_B' in k or 'lora_embedding' in k
        ]

        if not lora_keys:
            if rank0_print:
                print("No LoRA keys found in checkpoint.")
                print("This checkpoint might contain merged weights or use a different format.")
                print("Checking checkpoint structure...")

            transformer_keys = [
                k for k in state_dict
                if 'transformer' in k.lower() and 'lora' not in k.lower()
            ]
            if not transformer_keys:
                if rank0_print:
                    print("Could not identify checkpoint format. Please check the checkpoint structure.")
                return False

            if rank0_print:
                print(f"Found {len(transformer_keys)} transformer keys")
                print("This appears to be a full fine-tuning checkpoint with merged weights.")
                print("Attempting to use Accelerator to load the checkpoint...")
            return _load_checkpoint_with_accelerator(
                pipeline, checkpoint_path, rank, rank0_print
            )

        if rank0_print:
            print(f"Found {len(lora_keys)} LoRA keys")
            print(f"Sample LoRA keys: {lora_keys[:5]}")

        pipeline.transformer.add_adapter(_build_transformer_lora_config(rank))
        if rank0_print:
            print("LoRA adapter configured")

        lora_state_dict = {
            _strip_checkpoint_prefix(key): state_dict[key] for key in lora_keys
        }
        if rank0_print:
            print(f"Extracted {len(lora_state_dict)} LoRA weights")
            print(f"Sample extracted keys: {list(lora_state_dict.keys())[:5]}")

        # strict=False because the checkpoint only carries the LoRA tensors;
        # the base weights are expected to show up as "missing".
        missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(
            lora_state_dict, strict=False
        )

        if rank0_print:
            if missing_keys:
                _print_key_report("Missing", missing_keys)
            if unexpected_keys:
                _print_key_report("Unexpected", unexpected_keys)

        # With strict=False a complete key mismatch is not an error: every
        # LoRA tensor lands in ``unexpected_keys`` and the adapter keeps its
        # initial weights. Report that as a failure so the caller does not
        # sample the baseline model while believing the LoRA is active.
        if len(lora_state_dict) - len(unexpected_keys) <= 0:
            if rank0_print:
                print("Error: none of the LoRA keys in the checkpoint match the transformer's parameters")
            return False

        if rank0_print:
            if hasattr(pipeline.transformer, 'peft_config'):
                print(f"LoRA config found: {list(pipeline.transformer.peft_config.keys())}")
            else:
                print("Warning: No peft_config found after loading LoRA")

            # Only lora_B is inspected: lora_A is Gaussian-initialised and
            # therefore non-zero even when nothing was loaded, whereas PEFT
            # initialises lora_B to zeros.
            loaded_param = next(
                (
                    name
                    for name, param in pipeline.transformer.named_parameters()
                    if 'lora_b' in name.lower() and param.abs().max().item() > 1e-6
                ),
                None,
            )
            if loaded_param is not None:
                print(f"Verified LoRA weights loaded (found non-zero LoRA param: {loaded_param})")
            else:
                print("Warning: LoRA weights may not have been loaded correctly (all LoRA params are zero or not found)")

            print("LoRA weights loaded successfully")

        return True

    except Exception as e:
        if rank0_print:
            print(f"Error loading LoRA from checkpoint: {e}")
            import traceback
            traceback.print_exc()
        return False
|
|
|
|
def load_captions_from_jsonl(jsonl_path):
    """Load one caption per JSON object from a JSONL file.

    For every non-blank line that parses to a JSON object, the caption is the
    first non-empty string found under ``caption``, ``text``, ``prompt`` or
    ``description`` (in that order). If none of those fields qualifies, the
    first non-empty string value of the object is used. Lines that are not
    valid JSON, are not JSON objects, or contain no usable string are
    skipped; the first two cases print a warning.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The list of stripped captions in file order, or an empty list when
        the file cannot be read.
    """
    preferred_fields = ('caption', 'text', 'prompt', 'description')
    captions = []
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Warning: Invalid JSON on line {line_num}: {e}")
                    continue

                # A bare list/string/number must not abort the whole file:
                # skip just this line instead of discarding every caption.
                if not isinstance(data, dict):
                    print(f"Warning: Line {line_num} is not a JSON object, skipping")
                    continue

                # Preferred fields first, then any value in document order.
                candidates = [data.get(field) for field in preferred_fields]
                candidates.extend(data.values())
                caption = next(
                    (
                        value.strip()
                        for value in candidates
                        if isinstance(value, str) and value.strip()
                    ),
                    None,
                )
                if caption is not None:
                    captions.append(caption)

    except FileNotFoundError:
        print(f"Error: JSONL file {jsonl_path} not found")
        return []
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading JSONL file {jsonl_path}: {e}")
        return []

    print(f"Loaded {len(captions)} captions from {jsonl_path}")
    return captions
|
|
|
|
def _resolve_dtype(mixed_precision):
    """Map the ``--mixed_precision`` choice to the torch dtype of the pipeline."""
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


def _sample_filename(global_index, caption_idx, image_idx):
    """Return the PNG filename used for one generated sample."""
    return f"{global_index:06d}_cap{caption_idx:04d}_img{image_idx:02d}.png"


def _load_lora_into_pipeline(pipeline, args, rank, device):
    """Load the LoRA checkpoint on this rank; return ``(lora_loaded, lora_source)``."""
    lora_loaded = False
    lora_source = "baseline"

    if args.lora_checkpoint_path:
        checkpoint_name = os.path.basename(args.lora_checkpoint_path.rstrip('/'))
        if rank == 0:
            print(f"Loading LoRA weights from checkpoint: {args.lora_checkpoint_path}")

        lora_loaded = load_lora_from_checkpoint_direct(
            pipeline,
            args.lora_checkpoint_path,
            rank=args.lora_rank,
            rank0_print=(rank == 0),
        )

        if lora_loaded:
            lora_source = checkpoint_name
            if rank == 0:
                print("Successfully loaded LoRA weights from checkpoint")
        else:
            if rank == 0:
                print("Failed to load LoRA weights directly from checkpoint")
                print("Trying alternative method: extracting LoRA weights first...")

            temp_lora_path = os.path.join(args.lora_checkpoint_path, "extracted_lora")
            extract_success = False
            if rank == 0:
                extract_success = extract_lora_from_checkpoint(
                    args.lora_checkpoint_path,
                    temp_lora_path,
                    rank=args.lora_rank,
                    rank0_only=True,
                )

            # Only rank 0 runs the extraction, so its verdict is broadcast.
            # Otherwise the other ranks would skip the load below and sample
            # the baseline model. The collective also makes them wait until
            # rank 0 has finished writing the file.
            success_flag = torch.tensor([int(bool(extract_success))], device=device)
            dist.broadcast(success_flag, src=0)
            extract_success = bool(success_flag.item())

            weights_file = os.path.join(temp_lora_path, "pytorch_lora_weights.safetensors")
            if extract_success and os.path.exists(weights_file):
                if rank == 0:
                    print(f"Loading extracted LoRA weights from {temp_lora_path}")
                try:
                    pipeline.load_lora_weights(temp_lora_path)
                except Exception as e:
                    # Best effort: fall back to the baseline model below.
                    if rank == 0:
                        print(f"Failed to load extracted LoRA weights: {e}")
                else:
                    lora_loaded = True
                    lora_source = f"{checkpoint_name}_extracted"
                    if rank == 0:
                        print("Successfully loaded extracted LoRA weights")

    if not lora_loaded and rank == 0:
        print("Warning: No LoRA weights loaded. Using baseline model.")

    return lora_loaded, lora_source


def _enable_memory_optimizations(pipeline, rank):
    """Turn on attention and VAE slicing when the pipeline offers them."""
    if hasattr(pipeline, 'enable_attention_slicing'):
        try:
            pipeline.enable_attention_slicing()
            if rank == 0:
                print(" - Attention slicing enabled")
        except Exception as e:
            if rank == 0:
                print(f" - Warning: Failed to enable attention slicing: {e}")
    elif rank == 0:
        print(" - Attention slicing not available for this pipeline")

    enable_vae_slicing_method = getattr(pipeline, 'enable_vae_slicing', None)
    if enable_vae_slicing_method is not None and callable(enable_vae_slicing_method):
        try:
            enable_vae_slicing_method()
            if rank == 0:
                print(" - VAE slicing enabled")
        except Exception as e:
            if rank == 0:
                print(f" - Warning: Failed to enable VAE slicing: {e}")
    elif rank == 0:
        print(" - VAE slicing not available for this pipeline (SD3 may not support this)")


def main(args):
    """Run distributed SD3 (LoRA) sampling and pack the results into an ``.npz``.

    Every process of the NCCL process group loads the pipeline, tries to load
    the LoRA checkpoint, and generates its share of the images. Image
    ``g`` (a global index) uses caption ``g // images_per_caption`` and is
    produced by rank ``g % world_size``. Existing PNG files in the output
    folder are counted so that already completed iterations are skipped.
    After sampling, rank 0 writes ``captions.txt`` (one ``filename<TAB>caption``
    line per saved image of the planned set) and builds the ``.npz`` archive.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``global_seed``, ``captions_jsonl``, ``images_per_caption``,
            ``max_samples``, ``mixed_precision``,
            ``pretrained_model_name_or_path``, ``revision``, ``variant``,
            ``lora_checkpoint_path``, ``lora_rank``, ``enable_cpu_offload``,
            ``guidance_scale``, ``num_inference_steps``, ``height``,
            ``width``, ``negative_prompt``, ``per_proc_batch_size`` and
            ``sample_dir``.

    Raises:
        ValueError: If ``args.per_proc_batch_size`` is smaller than 1.
        AssertionError: If no CUDA device is available.
    """
    if args.per_proc_batch_size < 1:
        raise ValueError("--per_proc_batch_size must be at least 1")
    assert torch.cuda.is_available(), "DDP采样需要至少一个GPU"
    torch.set_grad_enabled(False)

    # --- distributed setup -------------------------------------------------
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Global ranks exceed the local GPU count on multi-node jobs, so wrap
    # around the number of devices visible to this process.
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, device={device}, seed={seed}, world_size={world_size}, visible_devices={torch.cuda.device_count()}.")

    # --- captions ----------------------------------------------------------
    captions = []
    if args.captions_jsonl:
        if rank == 0:
            print(f"Loading captions from {args.captions_jsonl}")
        captions = load_captions_from_jsonl(args.captions_jsonl)
        if not captions and rank == 0:
            print("Warning: No captions loaded, using default caption")
    if not captions:
        captions = ["a beautiful high quality image"]

    total_images_requested = len(captions) * args.images_per_caption
    total_images_needed = max(0, min(total_images_requested, args.max_samples))
    if rank == 0:
        print(f"Will generate {args.images_per_caption} images for each of {len(captions)} captions")
        print(f"Total images requested: {total_images_requested}")
        print(f"Max samples limit: {args.max_samples}")
        print(f"Total images to generate: {total_images_needed}")

    # --- pipeline and LoRA weights -----------------------------------------
    dtype = _resolve_dtype(args.mixed_precision)

    if rank == 0:
        print(f"Loading SD3 pipeline from {args.pretrained_model_name_or_path}")

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=dtype,
    )

    _, lora_source = _load_lora_into_pipeline(pipeline, args, rank, device)

    # --- device placement --------------------------------------------------
    enable_cpu_offload = args.enable_cpu_offload
    if enable_cpu_offload and world_size > 1:
        if rank == 0:
            print(f"Warning: CPU offload is disabled in multi-GPU mode (world_size={world_size})")
            print("Using device-specific placement instead")
        enable_cpu_offload = False

    if enable_cpu_offload:
        if rank == 0:
            print("Enabling CPU offload to save memory (single GPU mode)")
        pipeline.enable_model_cpu_offload()
    else:
        if rank == 0:
            print(f"Moving pipeline to device {device} (multi-GPU mode)")
        pipeline = pipeline.to(device)

    if rank == 0:
        print("Enabling memory optimization options")
    _enable_memory_optimizations(pipeline, rank)

    if rank == 0:
        print(f"Pipeline device verification:")
        print(f" - Transformer device: {next(pipeline.transformer.parameters()).device}")
        print(f" - VAE device: {next(pipeline.vae.parameters()).device}")
        if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
            print(f" - Text encoder device: {next(pipeline.text_encoder.parameters()).device}")
    dist.barrier()

    pipeline.set_progress_bar_config(disable=True)

    # --- output folder -----------------------------------------------------
    folder_name = f"checkpoint-{lora_source}-rank{args.lora_rank}-guidance-{args.guidance_scale}-steps-{args.num_inference_steps}-size-{args.height}x{args.width}"
    sample_folder_dir = os.path.join(args.sample_dir, folder_name)
    caption_file = os.path.join(sample_folder_dir, "captions.txt")

    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
        # A caption list from an earlier run may be stale; it is rewritten
        # in full once sampling has finished.
        if os.path.exists(caption_file):
            os.remove(caption_file)
    dist.barrier()

    # --- work split and resume ---------------------------------------------
    n = args.per_proc_batch_size
    global_batch_size = n * world_size

    existing_samples = 0
    if os.path.exists(sample_folder_dir):
        existing_samples = sum(
            1
            for name in os.listdir(sample_folder_dir)
            if name.endswith(".png") and os.path.isfile(os.path.join(sample_folder_dir, name))
        )
    # Every rank must finish counting before any rank starts writing new
    # PNGs; otherwise ranks could disagree on the resume point and hang in
    # the per-iteration barrier.
    dist.barrier()

    # Round up so that every rank runs the same number of iterations.
    total_samples = int(math.ceil(total_images_needed / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"Existing samples: {existing_samples}")

    iterations = total_samples // global_batch_size
    done_iterations = existing_samples // global_batch_size

    progress = range(done_iterations, iterations)
    if rank == 0:
        progress = tqdm(progress)

    negative_prompt = args.negative_prompt or None
    # Autocast is only meaningful for reduced precision; float32 is not a
    # valid autocast target, so the context is disabled in that case.
    use_autocast = dtype != torch.float32
    autocast_dtype = dtype if use_autocast else None

    # --- sampling loop -----------------------------------------------------
    for i in progress:
        for j in range(n):
            global_index = i * global_batch_size + j * world_size + rank
            # Slots beyond the requested total only exist because of the
            # round-up above; skip them instead of generating throwaways.
            if global_index >= total_images_needed:
                continue

            caption_idx, image_idx = divmod(global_index, args.images_per_caption)
            caption = captions[caption_idx]

            image_seed = seed + i * 10000 + j * 1000 + rank
            generator = torch.Generator(device=device).manual_seed(image_seed)

            if i == done_iterations and j == 0 and rank < 2:
                print(f"[Rank {rank}] Generating image on device {device}, generator device: {generator.device}")

            # torch.autocast expects the device *type* ("cuda"), not "cuda:N".
            with torch.autocast(device.type, dtype=autocast_dtype, enabled=use_autocast):
                image = pipeline(
                    prompt=caption,
                    negative_prompt=negative_prompt,
                    height=args.height,
                    width=args.width,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    generator=generator,
                    num_images_per_prompt=1,
                ).images[0]

            filename = _sample_filename(global_index, caption_idx, image_idx)
            image.save(os.path.join(sample_folder_dir, filename))

        torch.cuda.empty_cache()
        dist.barrier()

    dist.barrier()

    # --- finalisation on rank 0 --------------------------------------------
    if rank == 0:
        saved_pngs = {
            name for name in os.listdir(sample_folder_dir) if name.endswith(".png")
        }
        actual_num_samples = len(saved_pngs)
        print(f"Actually generated {actual_num_samples} images")

        # The filename of every sample is a pure function of its global
        # index, so rank 0 can list the images of all ranks (including those
        # of an earlier, resumed run).
        with open(caption_file, "w", encoding="utf-8") as f:
            for global_index in range(total_images_needed):
                caption_idx, image_idx = divmod(global_index, args.images_per_caption)
                filename = _sample_filename(global_index, caption_idx, image_idx)
                if filename in saved_pngs:
                    f.write(f"{filename}\t{captions[caption_idx]}\n")

        npz_samples = min(actual_num_samples, total_images_needed)
        create_npz_from_sample_folder(sample_folder_dir, npz_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # in the order they appear in the --help output: model, sampling,
    # output, captions, and runtime settings.
    cli_options = [
        ("--pretrained_model_name_or_path", dict(
            type=str,
            default="stabilityai/stable-diffusion-3-medium-diffusers",
            help="预训练模型路径或HuggingFace模型ID",
        )),
        ("--lora_checkpoint_path", dict(
            type=str,
            required=True,
            help="LoRA checkpoint目录路径(包含model.safetensors的目录)",
        )),
        ("--lora_rank", dict(type=int, default=64, help="LoRA rank(必须与训练时一致)")),
        ("--revision", dict(type=str, default=None, help="模型修订版本")),
        ("--variant", dict(type=str, default=None, help="模型变体,如fp16")),
        ("--num_inference_steps", dict(type=int, default=28, help="推理步数")),
        ("--guidance_scale", dict(type=float, default=7.0, help="引导尺度")),
        ("--height", dict(type=int, default=1024, help="生成图像高度")),
        ("--width", dict(type=int, default=1024, help="生成图像宽度")),
        ("--negative_prompt", dict(type=str, default="", help="负面提示词")),
        ("--per_proc_batch_size", dict(type=int, default=1, help="每个进程的批处理大小")),
        ("--sample_dir", dict(type=str, default="sd3_lora_samples", help="样本保存目录")),
        ("--captions_jsonl", dict(
            type=str,
            required=True,
            help="包含caption列表的JSONL文件路径",
        )),
        ("--images_per_caption", dict(type=int, default=1, help="每个caption生成的图像数量")),
        ("--max_samples", dict(type=int, default=30000, help="最大样本生成数量")),
        ("--global_seed", dict(type=int, default=42, help="全局随机种子")),
        ("--mixed_precision", dict(
            type=str,
            default="fp16",
            choices=["no", "fp16", "bf16"],
            help="混合精度类型",
        )),
        ("--enable_cpu_offload", dict(action="store_true", help="启用CPU offload以节省显存")),
    ]

    parser = argparse.ArgumentParser(description="SD3 LoRA分布式采样脚本 - 从checkpoint加载")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
|
|