# SyncAI / train_lora.py
# (Hugging Face upload-page metadata removed — it was not valid Python.)
"""SDXL LoRA training script — run on Google Colab (T4 GPU).
Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images.
The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py.
Setup:
1. Open Google Colab with a T4 GPU runtime
2. Upload this script, or copy each section into separate cells
3. Upload your style images to lora_training_data/
4. Add a .txt caption file alongside each image
5. Run all cells in order
6. Download the trained .safetensors from styles/
Dataset structure:
lora_training_data/
image_001.png
image_001.txt # "a sunset landscape with mountains, in sks style"
image_002.jpg
image_002.txt # "a woman silhouetted against warm sky, in sks style"
...
"""
import json
import subprocess
import sys
from pathlib import Path
from typing import Optional
# ---------------------------------------------------------------------------
# Config — adjust these before training
# ---------------------------------------------------------------------------
# Trigger word that activates your style in prompts
TRIGGER_WORD = "sks"
# Instance prompt passed to the DreamBooth trainer (see train()).
INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style"

# Training hyperparameters (tuned for 15-20 images on T4 16GB).
# NOTE: "max_train_steps" is recomputed in __main__ as
# max(1500, num_images * 100) after the dataset is verified.
CONFIG = {
    "base_model": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae": "madebyollin/sdxl-vae-fp16-fix",  # fixes fp16 instability
    "resolution": 1024,
    "train_batch_size": 1,
    "gradient_accumulation_steps": 4,  # effective batch size = 4
    "learning_rate": 1e-4,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "max_train_steps": 1500,  # ~100 × num_images
    "rank": 16,  # LoRA rank (reduced from 32 to fit T4 16GB)
    "snr_gamma": 5.0,  # Min-SNR weighting for stable convergence
    "mixed_precision": "fp16",  # T4 doesn't support bf16
    "checkpointing_steps": 500,
    "seed": 42,
}

# Paths — dataset and training output live on Google Drive; exported
# weights land in the local styles/ folder for download.
DATASET_DIR = "/content/drive/MyDrive/lora_training_data"
OUTPUT_DIR = "/content/drive/MyDrive/lora_output"
FINAL_WEIGHTS_DIR = "styles"
# ---------------------------------------------------------------------------
# 1. Install dependencies
# ---------------------------------------------------------------------------
def install_dependencies():
    """Install training dependencies (run once per Colab session)."""
    # Fetch the diffusers repo — it ships the DreamBooth training script.
    if not Path("diffusers").exists():
        subprocess.check_call([
            "git", "clone", "--depth", "1",
            "https://github.com/huggingface/diffusers",
        ])
    pip_install = [sys.executable, "-m", "pip", "install", "-q"]
    # Order matters: diffusers from source, then the DreamBooth example's
    # requirements, then the runtime deps, and finally peft so its pinned
    # minimum version is not overwritten by an earlier install.
    for extra_args in (
        ["./diffusers"],
        ["-r", "diffusers/examples/dreambooth/requirements.txt"],
        ["transformers", "accelerate", "bitsandbytes", "safetensors", "Pillow"],
        ["peft>=0.17.0"],
    ):
        subprocess.check_call(pip_install + extra_args)
    print("Dependencies installed.")
# ---------------------------------------------------------------------------
# 2. Configure accelerate
# ---------------------------------------------------------------------------
def configure_accelerate():
    """Write a single-GPU accelerate config."""
    # Local import: accelerate only exists after install_dependencies() ran.
    from accelerate.utils import write_basic_config

    write_basic_config()
    print("Accelerate configured for single GPU.")
# ---------------------------------------------------------------------------
# 3. Prepare dataset
# ---------------------------------------------------------------------------
def verify_dataset(dataset_dir: str = DATASET_DIR) -> int:
    """Verify the dataset folder has images + metadata.jsonl (no .txt files).

    The DreamBooth script loads the folder via --dataset_name; stray .txt
    caption files would make it load as a text dataset, so they are rejected
    (captions belong in metadata.jsonl).

    Args:
        dataset_dir: Path to folder on Google Drive.

    Returns:
        Number of images found.

    Raises:
        FileNotFoundError: If the folder, the images, or metadata.jsonl is missing.
        RuntimeError: If stray .txt files are present.
    """
    dataset_path = Path(dataset_dir)
    # Fail with a clear message instead of a cryptic iterdir() error.
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}/")
    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    images = [f for f in dataset_path.iterdir() if f.suffix.lower() in image_extensions]
    if not images:
        raise FileNotFoundError(f"No images found in {dataset_dir}/.")
    metadata = dataset_path / "metadata.jsonl"
    if not metadata.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.")
    # Reject .txt files (they would cause the dataset to load as text).
    txt_files = list(dataset_path.glob("*.txt"))
    if txt_files:
        raise RuntimeError(
            f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. "
            f"Remove them — only images + metadata.jsonl should be present."
        )
    print(f"Dataset OK: {len(images)} images + metadata.jsonl")
    return len(images)
# ---------------------------------------------------------------------------
# 4. Train
# ---------------------------------------------------------------------------
def train(
    dataset_dir: str = DATASET_DIR,
    output_dir: str = OUTPUT_DIR,
    resume: bool = False,
):
    """Launch DreamBooth LoRA training on SDXL.

    Args:
        dataset_dir: Path to prepared dataset.
        output_dir: Where to save checkpoints and final weights.
        resume: If True, resume from the latest checkpoint.
    """
    cfg = CONFIG
    # Flags are passed straight through to diffusers' DreamBooth SDXL script.
    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py",
        f"--pretrained_model_name_or_path={cfg['base_model']}",
        f"--pretrained_vae_model_name_or_path={cfg['vae']}",
        f"--dataset_name={dataset_dir}",
        "--image_column=image",
        "--caption_column=prompt",
        f"--output_dir={output_dir}",
        f"--resolution={cfg['resolution']}",
        f"--train_batch_size={cfg['train_batch_size']}",
        f"--gradient_accumulation_steps={cfg['gradient_accumulation_steps']}",
        "--gradient_checkpointing",
        "--use_8bit_adam",
        f"--mixed_precision={cfg['mixed_precision']}",
        f"--learning_rate={cfg['learning_rate']}",
        f"--lr_scheduler={cfg['lr_scheduler']}",
        f"--lr_warmup_steps={cfg['lr_warmup_steps']}",
        f"--max_train_steps={cfg['max_train_steps']}",
        f"--rank={cfg['rank']}",
        f"--snr_gamma={cfg['snr_gamma']}",
        f"--instance_prompt={INSTANCE_PROMPT}",
        f"--checkpointing_steps={cfg['checkpointing_steps']}",
        f"--seed={cfg['seed']}",
    ]
    if resume:
        cmd.append("--resume_from_checkpoint=latest")
    print("Starting training...")
    print(f" Model: {cfg['base_model']}")
    print(f" Steps: {cfg['max_train_steps']}")
    print(f" Rank: {cfg['rank']}")
    print(f" LR: {cfg['learning_rate']}")
    print(f" Resume: {resume}")
    print()
    # Stream the trainer's merged stdout/stderr line by line so the tqdm
    # progress bar and any errors are visible in the notebook output.
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        bufsize=1, text=True,
    ) as process:
        for line in process.stdout:
            print(line, end="", flush=True)
    # Popen.__exit__ waits for the process, so returncode is final here.
    if process.returncode != 0:
        raise RuntimeError(f"Training failed with exit code {process.returncode}")
    print(f"\nTraining complete! Weights saved to {output_dir}/")
# ---------------------------------------------------------------------------
# 5. Copy weights to styles/
# ---------------------------------------------------------------------------
def export_weights(
    output_dir: str = OUTPUT_DIR,
    styles_dir: str = FINAL_WEIGHTS_DIR,
    style_name: str = "custom-style",
):
    """Copy trained LoRA weights to the styles directory.

    Looks for final weights first, falls back to latest checkpoint.
    """
    import shutil

    out = Path(output_dir)
    # Preferred location: the final weights written at the end of training.
    src = out / "pytorch_lora_weights.safetensors"
    if not src.exists():
        # Fallback: the newest checkpoint-<step> directory, ordered numerically.
        by_step = sorted(
            out.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if by_step:
            newest = by_step[-1]
            # Weights may live at either of these paths inside a checkpoint.
            for candidate in (
                newest / "pytorch_lora_weights.safetensors",
                newest / "unet" / "adapter_model.safetensors",
            ):
                if candidate.exists():
                    src = candidate
                    print(f"Using checkpoint: {newest.name}")
                    break
    if not src.exists():
        raise FileNotFoundError(
            f"No weights found in {output_dir}/. "
            f"Check that training completed or a checkpoint was saved."
        )
    target_dir = Path(styles_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / f"{style_name}.safetensors"
    shutil.copy2(src, target)
    size_mb = target.stat().st_size / (1024 * 1024)
    print(f"Exported weights: {target} ({size_mb:.1f} MB)")
    print(f"Download this file and place it in your project's styles/ folder.")
# ---------------------------------------------------------------------------
# 6. Backup to Google Drive
# ---------------------------------------------------------------------------
def backup_to_drive(output_dir: str = OUTPUT_DIR):
    """Copy training output to Google Drive for safety.

    Note: If OUTPUT_DIR already points to Drive, this is a no-op.
    """
    import shutil

    destination = Path("/content/drive/MyDrive/lora_output")
    # Skip the copy when the output folder already IS the Drive folder.
    if Path(output_dir).resolve() == destination.resolve():
        print("Output already on Google Drive — no backup needed.")
        return
    # Mount Drive on demand (Colab-only module, hence the local import).
    if not Path("/content/drive/MyDrive").exists():
        from google.colab import drive
        drive.mount("/content/drive")
    shutil.copytree(output_dir, str(destination), dirs_exist_ok=True)
    print(f"Backed up to {destination}")
# ---------------------------------------------------------------------------
# 7. Test inference
# ---------------------------------------------------------------------------
def test_inference(
    output_dir: str = OUTPUT_DIR,
    prompt: Optional[str] = None,
):
    """Generate a test image with the trained LoRA + Hyper-SD to verify quality.

    Uses the same setup as image_generator_hf.py for accurate results.

    Args:
        output_dir: Training output folder containing the LoRA weights.
        prompt: Test prompt; defaults to a landscape prompt with TRIGGER_WORD.

    Returns:
        The generated PIL image (also saved to test_output.png).

    Raises:
        FileNotFoundError: If no trained weights exist in output_dir.
    """
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
    from huggingface_hub import hf_hub_download

    if prompt is None:
        prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style"
    # Locate the trained weights up front so a missing file fails fast,
    # before the (slow) base-model download and load.
    output_path = Path(output_dir)
    weights_file = output_path / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        # Previously a missing file was passed straight to load_lora_weights,
        # producing a confusing downstream error after minutes of loading.
        raise FileNotFoundError(
            f"No trained LoRA weights found in {output_dir}/. "
            f"Run train() first, or check that a checkpoint was saved."
        )
    print("Loading model + LoRA for test inference...")
    vae = AutoencoderKL.from_pretrained(
        CONFIG["vae"], torch_dtype=torch.float16,
    )
    pipe = DiffusionPipeline.from_pretrained(
        CONFIG["base_model"],
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
    # Load Hyper-SD (same as image_generator_hf.py)
    hyper_path = hf_hub_download(
        "ByteDance/Hyper-SD", "Hyper-SDXL-8steps-CFG-lora.safetensors",
    )
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")
    # Load the trained style LoRA alongside Hyper-SD as a named adapter.
    pipe.load_lora_weights(
        str(weights_file.parent),
        weight_name=weights_file.name,
        adapter_name="style",
    )
    # Hyper-SD at 0.125 strength, style LoRA at full strength — mirrors
    # the adapter weights used in image_generator_hf.py.
    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[0.125, 1.0],
    )
    # trailing timestep spacing for the 8-step distilled schedule —
    # NOTE(review): matches image_generator_hf.py; confirm if that file changes.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing",
    )
    image = pipe(
        prompt=prompt,
        negative_prompt="blurry, low quality, deformed, ugly, text, watermark",
        num_inference_steps=8,
        guidance_scale=5.0,
        height=1344,
        width=768,
    ).images[0]
    image.save("test_output.png")
    print(f"Test image saved to test_output.png")
    print(f"Prompt: {prompt}")
    return image
# ---------------------------------------------------------------------------
# Main — run all steps in sequence
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("SDXL LoRA Training Pipeline")
    print(banner)
    # Steps 1-2: environment setup (install deps, configure accelerate).
    install_dependencies()
    configure_accelerate()
    # Step 3: sanity-check the dataset, then scale step count to its size.
    num_images = verify_dataset()
    steps = max(1500, num_images * 100)
    CONFIG["max_train_steps"] = steps
    print(f"Adjusted training steps to {steps} ({num_images} images × 100)")
    # Step 4: train the LoRA.
    train()
    # Step 5: back up the raw training output.
    backup_to_drive()
    # Step 6: export the final .safetensors for download.
    export_weights(style_name="custom-style")
    # Step 7: smoke-test the trained weights.
    test_inference()
    print("\nDone! Download styles/custom-style.safetensors")