#!/usr/bin/env python3
"""
Create Training Config

Reads the master config.yaml and generates an ai-toolkit compatible
YAML training config at configs/train_sdxl_lora.yaml.
"""

import os
import sys
import argparse
import logging
from pathlib import Path

import yaml

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# Step budget used when the master config gives no positive `max_steps`.
DEFAULT_MAX_STEPS = 12800
# Number of LR-scheduler cycles used when `num_cycles` is not configured.
DEFAULT_NUM_CYCLES = 3


def load_config(config_path: str = "configs/config.yaml") -> dict:
    """
    Load the master YAML config.

    Args:
        config_path: Path to the master config file.

    Returns:
        The parsed top-level mapping. An empty YAML document yields ``{}``
        so callers can chain ``.get`` without a ``None`` check.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {config_path}, "
            f"got {type(loaded).__name__}"
        )
    return loaded


def _section(mapping: dict, *keys: str) -> dict:
    """Return the nested section at ``keys``, or ``{}`` if any level is missing or null."""
    node = mapping
    for key in keys:
        node = node.get(key)
        # A key written as `training:` with no body loads as None; treat it
        # the same as an absent key so defaults still apply.
        if node is None:
            return {}
    return node


def _build_aitk_config(config: dict) -> dict:
    """Translate the master config mapping into the ai-toolkit job mapping (no I/O)."""
    config = config or {}

    sdxl_cfg = _section(config, "models", "sdxl")
    training_cfg = _section(config, "training")
    sdxl_lora_cfg = _section(training_cfg, "sdxl_lora")
    lora_cfg = _section(sdxl_lora_cfg, "lora")
    optim_cfg = _section(sdxl_lora_cfg, "optimizer")
    sched_cfg = _section(sdxl_lora_cfg, "scheduler")
    snr_cfg = _section(sdxl_lora_cfg, "min_snr_gamma")
    checkpoint_cfg = _section(sdxl_lora_cfg, "checkpointing")
    paths_cfg = _section(config, "paths")

    # Base model
    base_model = sdxl_cfg.get("repo_id", "stabilityai/stable-diffusion-xl-base-1.0")

    # Paths (made absolute so the generated config does not depend on the
    # working directory of whatever consumes it)
    data_dir = os.path.abspath(_section(paths_cfg, "data").get("train", "data/train"))
    output_dir = os.path.abspath(
        _section(paths_cfg, "models", "sdxl").get("checkpoints", "models/sdxl/checkpoints")
    )

    # LoRA params
    rank = lora_cfg.get("rank", 32)
    alpha = lora_cfg.get("alpha", 16)
    dropout = lora_cfg.get("dropout", 0.05)

    # Training params
    batch_size = sdxl_lora_cfg.get("batch_size", 1)
    grad_accum = sdxl_lora_cfg.get("gradient_accumulation_steps", 4)
    warmup_steps = sched_cfg.get("warmup_steps", 100)
    betas = optim_cfg.get("betas", [0.9, 0.999])

    # PyYAML resolves floats per YAML 1.1, which requires a decimal point:
    # `1e-4` is loaded as the *string* "1e-4". Coerce so `lr / 10` below works
    # and a number (not a quoted string) is written to the output file.
    lr = float(optim_cfg.get("learning_rate", 1e-4))
    weight_decay = float(optim_cfg.get("weight_decay", 0.01))

    # A missing, null, zero or negative max_steps falls back to the default.
    max_steps = sdxl_lora_cfg.get("max_steps")
    steps = max_steps if max_steps is not None and max_steps > 0 else DEFAULT_MAX_STEPS

    # Clamp to >= 1 so a configured `num_cycles: 0` cannot divide by zero.
    raw_cycles = sched_cfg.get("num_cycles")
    num_cycles = DEFAULT_NUM_CYCLES if raw_cycles is None else max(1, int(raw_cycles))

    # Resolution
    height = sdxl_cfg.get("height", 1024)
    width = sdxl_cfg.get("width", 1024)

    # Seed
    seed = _section(config, "project").get("seed", 42)

    # Mixed precision
    dtype = _section(training_cfg, "mixed_precision").get("dtype", "bf16")

    # Key order is significant: the file is dumped with sort_keys=False.
    return {
        "job": "extension",
        "config": {
            "name": "campus_ai_poster_sdxl",
            "process": [
                {
                    "type": "sd_trainer",
                    "training_folder": output_dir,
                    "device": "cuda:0",
                    "trigger_word": "campus_ai_poster",
                    "network": {
                        "type": "lora",
                        "linear": rank,
                        "linear_alpha": alpha,
                        "dropout": dropout,
                        "network_kwargs": {
                            "lora_plus_lr_ratio": lora_cfg.get("lora_plus_ratio", 1.0),
                        },
                    },
                    "save": {
                        "dtype": dtype,
                        "save_every": checkpoint_cfg.get("save_steps", 500),
                        "max_step_saves_to_keep": checkpoint_cfg.get("save_total_limit", 5),
                    },
                    "datasets": [
                        {
                            "folder_path": data_dir,
                            "caption_ext": "txt",
                            "caption_dropout_rate": 0.1,
                            "shuffle_tokens": True,
                            "cache_latents_to_disk": True,
                            "num_workers": 8,
                            "resolution": [width, height],
                        }
                    ],
                    "train": {
                        "batch_size": batch_size,
                        "steps": steps,
                        "gradient_accumulation_steps": grad_accum,
                        "train_unet": True,
                        "train_text_encoder": False,
                        "disable_sampling": True,
                        "gradient_checkpointing": True,
                        "noise_scheduler": "ddpm",
                        "optimizer": optim_cfg.get("type", "adamw8bit"),
                        "lr": lr,
                        "lr_warmup_steps": warmup_steps,
                        # None is dumped as YAML null when min-SNR weighting is disabled.
                        "min_snr_gamma": (
                            snr_cfg.get("gamma", 5.0) if snr_cfg.get("enabled", True) else None
                        ),
                        "optimizer_params": {
                            "weight_decay": weight_decay,
                            "betas": betas,
                        },
                        "ema_config": {
                            "use_ema": True,
                            "ema_decay": 0.999,
                        },
                        "dtype": dtype,
                        "lr_scheduler": sched_cfg.get("type", "cosine_with_restarts"),
                        "lr_scheduler_params": {
                            # Length of one restart cycle, at least one step.
                            "T_0": max(1, steps // num_cycles),
                            "T_mult": 1,
                            "eta_min": lr / 10,
                        },
                    },
                    "model": {
                        "name_or_path": base_model,
                        "is_xl": True,
                    },
                    "sample": {
                        "sampler": "euler_a",
                        "sample_every": 999999,
                        "width": width,
                        "height": height,
                        "prompts": [
                            "campus_ai_poster a vibrant technology fest poster with neon colors and bold typography",
                            "campus_ai_poster a colorful Diwali celebration poster with golden diyas and rangoli",
                            "campus_ai_poster a professional workshop seminar poster with modern minimalist design",
                            "campus_ai_poster a dynamic sports tournament poster with action silhouettes",
                        ],
                        "neg": "",
                        "seed": seed,
                        "walk_seed": True,
                        "guidance_scale": 5,
                        "sample_steps": 28,
                    },
                    "logging": {
                        "log_every": _section(sdxl_lora_cfg, "logging").get("steps", 10),
                        "use_wandb": _section(config, "monitoring", "wandb").get("enabled", False),
                        "verbose": True,
                    },
                }
            ],
            "meta": {
                "name": "campus_ai_v1",
                "version": "1.0",
            },
        },
    }


def generate_ai_toolkit_config(config: dict, output_path: str) -> dict:
    """
    Generate an ai-toolkit compatible training config from master config.

    ai-toolkit expects a specific YAML schema for training SDXL LoRA.

    Args:
        config: Parsed master config mapping (as returned by ``load_config``).
            Missing or null sections fall back to built-in defaults.
        output_path: Destination YAML file. Parent directories are created
            if they do not exist.

    Returns:
        The generated ai-toolkit config mapping that was written to disk.
    """
    aitk_config = _build_aitk_config(config)

    # Write output
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump(aitk_config, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

    # Summarise from the generated mapping so the log reflects exactly what
    # was written, including any defaults that were applied.
    process = aitk_config["config"]["process"][0]
    network = process["network"]
    train = process["train"]
    width, height = process["datasets"][0]["resolution"]

    logger.info("ai-toolkit training config written to: %s", output_file)
    logger.info(" Base model: %s", process["model"]["name_or_path"])
    logger.info(" Dataset dir: %s", process["datasets"][0]["folder_path"])
    logger.info(" Output dir: %s", process["training_folder"])
    logger.info(" LoRA rank: %s, alpha: %s", network["linear"], network["linear_alpha"])
    logger.info(
        " Batch size: %s, Grad accum: %s",
        train["batch_size"],
        train["gradient_accumulation_steps"],
    )
    logger.info(" Learning rate: %s", train["lr"])
    logger.info(" Resolution: %sx%s", width, height)
    logger.info(" Mixed precision: %s", train["dtype"])

    return aitk_config


def main():
    """Parse CLI arguments, load the master config and write the ai-toolkit config."""
    parser = argparse.ArgumentParser(description="Generate ai-toolkit Training Config")
    parser.add_argument("--config", default="configs/config.yaml",
                        help="Path to master config.yaml")
    parser.add_argument("--output", default="configs/train_sdxl_lora.yaml",
                        help="Output path for ai-toolkit config")
    args = parser.parse_args()

    config = load_config(args.config)
    generate_ai_toolkit_config(config, args.output)


if __name__ == "__main__":
    main()