"""SDXL LoRA training script — run on Google Colab (T4 GPU). Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images. The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py. Setup: 1. Open Google Colab with a T4 GPU runtime 2. Upload this script, or copy each section into separate cells 3. Upload your style images to lora_training_data/ 4. Add a .txt caption file alongside each image 5. Run all cells in order 6. Download the trained .safetensors from styles/ Dataset structure: lora_training_data/ image_001.png image_001.txt # "a sunset landscape with mountains, in sks style" image_002.jpg image_002.txt # "a woman silhouetted against warm sky, in sks style" ... """ import json import subprocess import sys from pathlib import Path # --------------------------------------------------------------------------- # Config — adjust these before training # --------------------------------------------------------------------------- # Trigger word that activates your style in prompts TRIGGER_WORD = "sks" INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style" # Training hyperparameters (tuned for 15-20 images on T4 16GB) CONFIG = { "base_model": "stabilityai/stable-diffusion-xl-base-1.0", "vae": "madebyollin/sdxl-vae-fp16-fix", # fixes fp16 instability "resolution": 1024, "train_batch_size": 1, "gradient_accumulation_steps": 4, # effective batch size = 4 "learning_rate": 1e-4, "lr_scheduler": "constant", "lr_warmup_steps": 0, "max_train_steps": 1500, # ~100 × num_images "rank": 16, # LoRA rank (reduced from 32 to fit T4 16GB) "snr_gamma": 5.0, # Min-SNR weighting for stable convergence "mixed_precision": "fp16", # T4 doesn't support bf16 "checkpointing_steps": 500, "seed": 42, } # Paths DATASET_DIR = "/content/drive/MyDrive/lora_training_data" OUTPUT_DIR = "/content/drive/MyDrive/lora_output" FINAL_WEIGHTS_DIR = "styles" # --------------------------------------------------------------------------- # 1. 
# ---------------------------------------------------------------------------


def _pip_install(*args: str) -> None:
    """Quietly pip-install the given packages/args into this interpreter."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *args])


def install_dependencies():
    """Install training dependencies (run once per Colab session).

    Clones the diffusers repo (for the DreamBooth training script), then
    installs diffusers from source, the DreamBooth requirements, and the
    remaining runtime deps.
    """
    # Clone diffusers for the training script
    if not Path("diffusers").exists():
        subprocess.check_call([
            "git", "clone", "--depth", "1",
            "https://github.com/huggingface/diffusers",
        ])

    # Install diffusers from source + DreamBooth requirements
    _pip_install("./diffusers")
    _pip_install("-r", "diffusers/examples/dreambooth/requirements.txt")

    # Install remaining deps — peft last to ensure correct version
    _pip_install("transformers", "accelerate", "bitsandbytes", "safetensors", "Pillow")
    _pip_install("peft>=0.17.0")

    print("Dependencies installed.")


# ---------------------------------------------------------------------------
# 2. Configure accelerate
# ---------------------------------------------------------------------------


def configure_accelerate():
    """Write a single-GPU accelerate config."""
    from accelerate.utils import write_basic_config

    write_basic_config()
    print("Accelerate configured for single GPU.")


# ---------------------------------------------------------------------------
# 3. Prepare dataset
# ---------------------------------------------------------------------------


def verify_dataset(dataset_dir: str = DATASET_DIR) -> int:
    """Verify dataset folder has images + metadata.jsonl (no .txt files).

    Args:
        dataset_dir: Path to folder on Google Drive.

    Returns:
        Number of images found.

    Raises:
        FileNotFoundError: If no images or no metadata.jsonl are present.
        RuntimeError: If stray .txt files are present in the folder.
    """
    dataset_path = Path(dataset_dir)
    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    # is_file() guard: don't count a directory that happens to end in .png
    images = [
        f for f in dataset_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    ]
    metadata = dataset_path / "metadata.jsonl"

    if not images:
        raise FileNotFoundError(f"No images found in {dataset_dir}/.")
    if not metadata.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.")

    # Fail fast on .txt files (would cause the dataset to load as text)
    txt_files = list(dataset_path.glob("*.txt"))
    if txt_files:
        raise RuntimeError(
            f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. "
            f"Remove them — only images + metadata.jsonl should be present."
        )

    print(f"Dataset OK: {len(images)} images + metadata.jsonl")
    return len(images)


# ---------------------------------------------------------------------------
# 4. Train
# ---------------------------------------------------------------------------


def train(
    dataset_dir: str = DATASET_DIR,
    output_dir: str = OUTPUT_DIR,
    resume: bool = False,
):
    """Launch DreamBooth LoRA training on SDXL.

    Args:
        dataset_dir: Path to prepared dataset.
        output_dir: Where to save checkpoints and final weights.
        resume: If True, resume from the latest checkpoint.

    Raises:
        RuntimeError: If the training process exits with a non-zero code.
    """
    cfg = CONFIG
    cmd = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py",
        f"--pretrained_model_name_or_path={cfg['base_model']}",
        f"--pretrained_vae_model_name_or_path={cfg['vae']}",
        f"--dataset_name={dataset_dir}",
        "--image_column=image",
        "--caption_column=prompt",
        f"--output_dir={output_dir}",
        f"--resolution={cfg['resolution']}",
        f"--train_batch_size={cfg['train_batch_size']}",
        f"--gradient_accumulation_steps={cfg['gradient_accumulation_steps']}",
        "--gradient_checkpointing",
        "--use_8bit_adam",
        f"--mixed_precision={cfg['mixed_precision']}",
        f"--learning_rate={cfg['learning_rate']}",
        f"--lr_scheduler={cfg['lr_scheduler']}",
        f"--lr_warmup_steps={cfg['lr_warmup_steps']}",
        f"--max_train_steps={cfg['max_train_steps']}",
        f"--rank={cfg['rank']}",
        f"--snr_gamma={cfg['snr_gamma']}",
        f"--instance_prompt={INSTANCE_PROMPT}",
        f"--checkpointing_steps={cfg['checkpointing_steps']}",
        f"--seed={cfg['seed']}",
    ]
    if resume:
        cmd.append("--resume_from_checkpoint=latest")

    print("Starting training...")
    print(f"  Model: {cfg['base_model']}")
    print(f"  Steps: {cfg['max_train_steps']}")
    print(f"  Rank: {cfg['rank']}")
    print(f"  LR: {cfg['learning_rate']}")
    print(f"  Resume: {resume}")
    print()

    # Stream output live so the progress bar and errors are visible.
    # The `with` block closes the stdout pipe and waits for the process
    # (the original leaked the pipe and called wait() manually).
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,
        text=True,
    ) as process:
        for line in process.stdout:  # stdout is a pipe (stdout=PIPE above)
            print(line, end="", flush=True)

    if process.returncode != 0:
        raise RuntimeError(f"Training failed with exit code {process.returncode}")

    print(f"\nTraining complete! Weights saved to {output_dir}/")


# ---------------------------------------------------------------------------
# 5. Copy weights to styles/
# ---------------------------------------------------------------------------


def export_weights(
    output_dir: str = OUTPUT_DIR,
    styles_dir: str = FINAL_WEIGHTS_DIR,
    style_name: str = "custom-style",
):
    """Copy trained LoRA weights to the styles directory.

    Looks for final weights first, falls back to latest checkpoint.

    Args:
        output_dir: Training output directory to read weights from.
        styles_dir: Destination directory (created if missing).
        style_name: Base name for the exported .safetensors file.

    Raises:
        FileNotFoundError: If neither final weights nor a checkpoint exist.
    """
    import shutil

    output_path = Path(output_dir)

    # Try final weights first
    src = output_path / "pytorch_lora_weights.safetensors"

    # Fall back to latest checkpoint (numeric sort: checkpoint-500 < checkpoint-1000)
    if not src.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            latest = checkpoints[-1]
            # Check common checkpoint weight locations
            for candidate in [
                latest / "pytorch_lora_weights.safetensors",
                latest / "unet" / "adapter_model.safetensors",
            ]:
                if candidate.exists():
                    src = candidate
                    print(f"Using checkpoint: {latest.name}")
                    break

    if not src.exists():
        raise FileNotFoundError(
            f"No weights found in {output_dir}/. "
            f"Check that training completed or a checkpoint was saved."
        )

    dst_dir = Path(styles_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    dst = dst_dir / f"{style_name}.safetensors"
    shutil.copy2(src, dst)

    size_mb = dst.stat().st_size / (1024 * 1024)
    print(f"Exported weights: {dst} ({size_mb:.1f} MB)")
    print("Download this file and place it in your project's styles/ folder.")


# ---------------------------------------------------------------------------
# 6. Backup to Google Drive
# ---------------------------------------------------------------------------


def backup_to_drive(output_dir: str = OUTPUT_DIR):
    """Copy training output to Google Drive for safety.

    Note: If OUTPUT_DIR already points to Drive, this is a no-op.
    """
    import shutil

    drive_path = Path("/content/drive/MyDrive/lora_output")
    if Path(output_dir).resolve() == drive_path.resolve():
        print("Output already on Google Drive — no backup needed.")
        return

    # Mount Drive on demand if it isn't available yet
    if not Path("/content/drive/MyDrive").exists():
        from google.colab import drive

        drive.mount("/content/drive")

    shutil.copytree(output_dir, str(drive_path), dirs_exist_ok=True)
    print(f"Backed up to {drive_path}")


# ---------------------------------------------------------------------------
# 7. Test inference
# ---------------------------------------------------------------------------


def test_inference(
    output_dir: str = OUTPUT_DIR,
    prompt: str | None = None,
):
    """Generate a test image with the trained LoRA + Hyper-SD to verify quality.

    Uses the same setup as image_generator_hf.py for accurate results.

    Args:
        output_dir: Training output directory containing the LoRA weights.
        prompt: Test prompt; defaults to a landscape prompt with TRIGGER_WORD.

    Returns:
        The generated PIL image (also saved to test_output.png).

    Raises:
        FileNotFoundError: If no trained weights are found in output_dir.
    """
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
    from huggingface_hub import hf_hub_download

    if prompt is None:
        prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style"

    print("Loading model + LoRA for test inference...")
    vae = AutoencoderKL.from_pretrained(
        CONFIG["vae"],
        torch_dtype=torch.float16,
    )
    pipe = DiffusionPipeline.from_pretrained(
        CONFIG["base_model"],
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    # Load Hyper-SD (same as image_generator_hf.py)
    hyper_path = hf_hub_download(
        "ByteDance/Hyper-SD",
        "Hyper-SDXL-8steps-CFG-lora.safetensors",
    )
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")

    # Load trained style LoRA (check final weights, then latest checkpoint)
    output_path = Path(output_dir)
    weights_file = output_path / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors"

    # Fail with a clear message instead of an opaque loader error
    if not weights_file.exists():
        raise FileNotFoundError(
            f"No trained LoRA weights found in {output_dir}/. "
            f"Run train() first, or check that a checkpoint was saved."
        )

    pipe.load_lora_weights(
        str(weights_file.parent),
        weight_name=weights_file.name,
        adapter_name="style",
    )
    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[0.125, 1.0],
    )

    # Hyper-SD requires trailing timestep spacing with DDIM
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
    )

    image = pipe(
        prompt=prompt,
        negative_prompt="blurry, low quality, deformed, ugly, text, watermark",
        num_inference_steps=8,
        guidance_scale=5.0,
        height=1344,
        width=768,
    ).images[0]

    image.save("test_output.png")
    print("Test image saved to test_output.png")
    print(f"Prompt: {prompt}")
    return image


# ---------------------------------------------------------------------------
# Main — run all steps in sequence
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print("=" * 60)
    print("SDXL LoRA Training Pipeline")
    print("=" * 60)

    # Step 1: Install
    install_dependencies()

    # Step 2: Configure
    configure_accelerate()

    # Step 3: Verify dataset and scale step count to dataset size
    num_images = verify_dataset()
    steps = max(1500, num_images * 100)
    CONFIG["max_train_steps"] = steps
    print(f"Adjusted training steps to {steps} ({num_images} images × 100)")

    # Step 4: Train
    train()

    # Step 5: Backup
    backup_to_drive()

    # Step 6: Export
    export_weights(style_name="custom-style")

    # Step 7: Test
    test_inference()

    print("\nDone! Download styles/custom-style.safetensors")