"""SDXL LoRA training script — run on Google Colab (T4 GPU).

Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images.
The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py.

Setup:
    1. Open Google Colab with a T4 GPU runtime
    2. Upload this script, or copy each section into separate cells
    3. Upload your style images to lora_training_data/
    4. Add captions via a metadata.jsonl file (see Dataset structure)
    5. Run all cells in order
    6. Download the trained .safetensors from styles/

Dataset structure:
    lora_training_data/
        image_001.png
        image_002.jpg
        ...
        metadata.jsonl   # one JSON line per image, e.g.
                         # {"file_name": "image_001.png", "prompt": "a sunset landscape with mountains, in sks style"}
"""
| |
|
| | import json |
| | import subprocess |
| | import sys |
| | from pathlib import Path |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# Token that marks the custom style in prompts; captions should contain it.
TRIGGER_WORD = "sks"
# Instance prompt passed to the DreamBooth trainer.
# NOTE(review): captions from metadata.jsonl appear to take precedence
# (--caption_column is also passed) — confirm against the trainer script.
INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style"
| |
|
| | |
# Hyperparameters forwarded to diffusers' train_dreambooth_lora_sdxl.py.
CONFIG = {
    "base_model": "stabilityai/stable-diffusion-xl-base-1.0",  # SDXL base checkpoint
    "vae": "madebyollin/sdxl-vae-fp16-fix",  # fp16-friendly VAE replacement
    "resolution": 1024,                      # training resolution (pixels)
    "train_batch_size": 1,                   # per-step batch; kept at 1 for T4 VRAM
    "gradient_accumulation_steps": 4,        # effective batch = 1 x 4
    "learning_rate": 1e-4,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "max_train_steps": 1500,                 # overridden in __main__ from image count
    "rank": 16,                              # LoRA rank (adapter capacity)
    "snr_gamma": 5.0,                        # SNR-based loss weighting (--snr_gamma)
    "mixed_precision": "fp16",
    "checkpointing_steps": 500,              # intermediate checkpoint interval
    "seed": 42,
}
| |
|
| | |
# Dataset and training output live on Google Drive so they survive Colab resets.
DATASET_DIR = "/content/drive/MyDrive/lora_training_data"
OUTPUT_DIR = "/content/drive/MyDrive/lora_output"
# Local folder the final .safetensors is exported to for download.
FINAL_WEIGHTS_DIR = "styles"
| |
|
| |
|
| | |
| | |
| | |
| |
|
def install_dependencies():
    """Install training dependencies (run once per Colab session).

    Clones the diffusers repo (for the DreamBooth training script), then
    pip-installs diffusers itself, the example requirements, and the
    remaining training packages.
    """
    # Shallow-clone diffusers only when it is not already present.
    if not Path("diffusers").exists():
        subprocess.check_call(
            ["git", "clone", "--depth", "1",
             "https://github.com/huggingface/diffusers"]
        )

    # Run each install as its own pip invocation, mirroring separate
    # Colab cells; -q keeps the notebook output short.
    pip_install = [sys.executable, "-m", "pip", "install", "-q"]
    for packages in (
        ["./diffusers"],
        ["-r", "diffusers/examples/dreambooth/requirements.txt"],
        ["transformers", "accelerate", "bitsandbytes", "safetensors", "Pillow"],
        ["peft>=0.17.0"],
    ):
        subprocess.check_call(pip_install + packages)

    print("Dependencies installed.")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def configure_accelerate():
    """Write a single-GPU accelerate config.

    Delegates to accelerate's default config writer, which detects the
    current machine (one GPU on a Colab T4). Imported lazily because
    accelerate is only available after install_dependencies() runs.
    """
    from accelerate.utils import write_basic_config

    write_basic_config()
    print("Accelerate configured for single GPU.")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def verify_dataset(dataset_dir: str = DATASET_DIR) -> int:
    """Verify dataset folder has images + metadata.jsonl (no .txt files).

    Args:
        dataset_dir: Path to folder on Google Drive.

    Returns:
        Number of images found.

    Raises:
        FileNotFoundError: If the folder, images, or metadata.jsonl are missing.
        RuntimeError: If stray .txt caption files are present.
        ValueError: If metadata.jsonl contains a malformed JSON line.
    """
    dataset_path = Path(dataset_dir)
    # Fail with a clear message instead of iterdir()'s cryptic error
    # when Drive is not mounted or the folder name is wrong.
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}/")

    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    images = [f for f in dataset_path.iterdir() if f.suffix.lower() in image_extensions]
    metadata = dataset_path / "metadata.jsonl"

    if not images:
        raise FileNotFoundError(f"No images found in {dataset_dir}/.")
    if not metadata.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.")

    # Stray .txt captions belong to the old per-file layout; the trainer
    # reads captions only from metadata.jsonl.
    txt_files = list(dataset_path.glob("*.txt"))
    if txt_files:
        raise RuntimeError(
            f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. "
            f"Remove them — only images + metadata.jsonl should be present."
        )

    # Catch malformed caption lines now, before a long training run starts.
    with metadata.open(encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"metadata.jsonl line {line_no} is not valid JSON: {err}"
                ) from err

    print(f"Dataset OK: {len(images)} images + metadata.jsonl")
    return len(images)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def train(
    dataset_dir: str = DATASET_DIR,
    output_dir: str = OUTPUT_DIR,
    resume: bool = False,
):
    """Launch DreamBooth LoRA training on SDXL.

    Args:
        dataset_dir: Path to prepared dataset.
        output_dir: Where to save checkpoints and final weights.
        resume: If True, resume from the latest checkpoint.

    Raises:
        RuntimeError: If the training subprocess exits non-zero.
    """
    cfg = CONFIG

    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py",
        f"--pretrained_model_name_or_path={cfg['base_model']}",
        f"--pretrained_vae_model_name_or_path={cfg['vae']}",
        f"--dataset_name={dataset_dir}",
        "--image_column=image",
        "--caption_column=prompt",
        f"--output_dir={output_dir}",
        f"--resolution={cfg['resolution']}",
        f"--train_batch_size={cfg['train_batch_size']}",
        f"--gradient_accumulation_steps={cfg['gradient_accumulation_steps']}",
        "--gradient_checkpointing",
        "--use_8bit_adam",
        f"--mixed_precision={cfg['mixed_precision']}",
        f"--learning_rate={cfg['learning_rate']}",
        f"--lr_scheduler={cfg['lr_scheduler']}",
        f"--lr_warmup_steps={cfg['lr_warmup_steps']}",
        f"--max_train_steps={cfg['max_train_steps']}",
        f"--rank={cfg['rank']}",
        f"--snr_gamma={cfg['snr_gamma']}",
        f"--instance_prompt={INSTANCE_PROMPT}",
        f"--checkpointing_steps={cfg['checkpointing_steps']}",
        f"--seed={cfg['seed']}",
    ]

    if resume:
        cmd.append("--resume_from_checkpoint=latest")

    print("Starting training...")
    print(f"  Model: {cfg['base_model']}")
    print(f"  Steps: {cfg['max_train_steps']}")
    print(f"  Rank: {cfg['rank']}")
    print(f"  LR: {cfg['learning_rate']}")
    print(f"  Resume: {resume}")
    print()

    # Stream trainer output live so the Colab cell shows progress.  The
    # context manager guarantees the stdout pipe is closed even if the
    # print loop raises (the original leaked the pipe on error).
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        bufsize=1, text=True,
    ) as process:
        for line in process.stdout:
            print(line, end="", flush=True)
        returncode = process.wait()
    if returncode != 0:
        raise RuntimeError(f"Training failed with exit code {returncode}")

    print(f"\nTraining complete! Weights saved to {output_dir}/")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def export_weights(
    output_dir: str = OUTPUT_DIR,
    styles_dir: str = FINAL_WEIGHTS_DIR,
    style_name: str = "custom-style",
):
    """Copy trained LoRA weights to the styles directory.

    Looks for final weights first, falls back to latest checkpoint.

    Args:
        output_dir: Training output directory.
        styles_dir: Destination directory for the exported .safetensors.
        style_name: Basename (without extension) for the exported file.

    Raises:
        FileNotFoundError: If neither final weights nor a checkpoint exist.
    """
    import shutil  # hoisted so the dependency is visible up front

    output_path = Path(output_dir)

    # Final weights written by the trainer on successful completion.
    src = output_path / "pytorch_lora_weights.safetensors"

    if not src.exists():
        # Fall back to the highest-numbered checkpoint-<step> directory.
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            latest = checkpoints[-1]
            # Checkpoint layout varies across diffusers versions; try both
            # known weight locations.
            for candidate in (
                latest / "pytorch_lora_weights.safetensors",
                latest / "unet" / "adapter_model.safetensors",
            ):
                if candidate.exists():
                    src = candidate
                    print(f"Using checkpoint: {latest.name}")
                    break

    if not src.exists():
        raise FileNotFoundError(
            f"No weights found in {output_dir}/. "
            f"Check that training completed or a checkpoint was saved."
        )

    dst_dir = Path(styles_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    dst = dst_dir / f"{style_name}.safetensors"
    shutil.copy2(src, dst)

    size_mb = dst.stat().st_size / (1024 * 1024)
    print(f"Exported weights: {dst} ({size_mb:.1f} MB)")
    print("Download this file and place it in your project's styles/ folder.")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def backup_to_drive(output_dir: str = OUTPUT_DIR):
    """Copy training output to Google Drive for safety.

    Note: If OUTPUT_DIR already points to Drive, this is a no-op.
    """
    import shutil

    drive_path = Path("/content/drive/MyDrive/lora_output")
    source = Path(output_dir)

    # Nothing to do when training already wrote straight to Drive.
    if source.resolve() == drive_path.resolve():
        print("Output already on Google Drive — no backup needed.")
        return

    # Mount Drive on first use in this Colab session.
    if not Path("/content/drive/MyDrive").exists():
        from google.colab import drive
        drive.mount("/content/drive")

    shutil.copytree(str(source), str(drive_path), dirs_exist_ok=True)
    print(f"Backed up to {drive_path}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def test_inference(
    output_dir: str = OUTPUT_DIR,
    prompt: str | None = None,
):
    """Generate a test image with the trained LoRA + Hyper-SD to verify quality.

    Uses the same setup as image_generator_hf.py for accurate results.

    Args:
        output_dir: Training output directory holding the LoRA weights.
        prompt: Optional test prompt; defaults to a landscape prompt
            containing the trigger word.

    Returns:
        The generated PIL image (also saved to test_output.png).

    Raises:
        FileNotFoundError: If no trained LoRA weights exist in output_dir.
    """
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
    from huggingface_hub import hf_hub_download

    if prompt is None:
        prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style"

    print("Loading model + LoRA for test inference...")

    # Replacement VAE loaded in fp16 (repo is the fp16-fix variant).
    vae = AutoencoderKL.from_pretrained(
        CONFIG["vae"], torch_dtype=torch.float16,
    )

    pipe = DiffusionPipeline.from_pretrained(
        CONFIG["base_model"],
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    # Hyper-SD acceleration LoRA for 8-step sampling.
    hyper_path = hf_hub_download(
        "ByteDance/Hyper-SD", "Hyper-SDXL-8steps-CFG-lora.safetensors",
    )
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")

    # Locate the trained style LoRA: final weights first, then the
    # latest checkpoint.
    output_path = Path(output_dir)
    weights_file = output_path / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        # Fail with a clear message instead of an opaque loader error on
        # a nonexistent path (the original fell through to the loader).
        raise FileNotFoundError(
            f"No trained LoRA weights found in {output_dir}/ — run train() first."
        )
    pipe.load_lora_weights(
        str(weights_file.parent),
        weight_name=weights_file.name,
        adapter_name="style",
    )

    # Low hyper-sd weight, full style weight.
    # NOTE(review): 0.125 matches the common Hyper-SD CFG-LoRA recipe —
    # confirm against image_generator_hf.py.
    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[0.125, 1.0],
    )

    # DDIM with trailing timestep spacing, as in image_generator_hf.py.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing",
    )

    image = pipe(
        prompt=prompt,
        negative_prompt="blurry, low quality, deformed, ugly, text, watermark",
        num_inference_steps=8,
        guidance_scale=5.0,
        height=1344,
        width=768,
    ).images[0]

    image.save("test_output.png")
    print("Test image saved to test_output.png")
    print(f"Prompt: {prompt}")

    return image
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    print("=" * 60)
    print("SDXL LoRA Training Pipeline")
    print("=" * 60)

    # Step 1: install diffusers + training dependencies (once per session).
    install_dependencies()

    # Step 2: write a single-GPU accelerate config.
    configure_accelerate()

    # Step 3: sanity-check the dataset, then scale training steps with
    # dataset size: ~100 steps per image, never fewer than 1500.
    num_images = verify_dataset()
    steps = max(1500, num_images * 100)
    CONFIG["max_train_steps"] = steps
    print(f"Adjusted training steps to {steps} ({num_images} images × 100)")

    # Step 4: run DreamBooth LoRA training.
    train()

    # Step 5: back up output (no-op when OUTPUT_DIR already points at Drive).
    backup_to_drive()

    # Step 6: copy the final .safetensors into styles/ for download.
    export_weights(style_name="custom-style")

    # Step 7: quick visual quality check with the trained LoRA.
    test_inference()

    print("\nDone! Download styles/custom-style.safetensors")
| |
|