File size: 13,359 Bytes
72f552e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
"""SDXL LoRA training script — run on Google Colab (T4 GPU).

Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images.
The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py.

Setup:
    1. Open Google Colab with a T4 GPU runtime
    2. Upload this script, or copy each section into separate cells
    3. Upload your style images to lora_training_data/
    4. Add a metadata.jsonl file with one caption entry per image
    5. Run all cells in order
    6. Download the trained .safetensors from styles/

Dataset structure:
    lora_training_data/
        metadata.jsonl   # {"file_name": "image_001.png", "prompt": "a sunset landscape with mountains, in sks style"}
        image_001.png
        image_002.jpg
        ...

Note: do NOT place .txt caption files in the dataset folder — the dataset
loader would then treat the folder as a text dataset (verify_dataset()
rejects them for this reason).
"""

import json
import shutil
import subprocess
import sys
from pathlib import Path


# ---------------------------------------------------------------------------
# Config — adjust these before training
# ---------------------------------------------------------------------------

# Trigger word that activates your style in prompts
TRIGGER_WORD = "sks"
# DreamBooth instance prompt passed to the training script via --instance_prompt.
INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style"

# Training hyperparameters (tuned for 15-20 images on T4 16GB)
CONFIG = {
    "base_model": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae": "madebyollin/sdxl-vae-fp16-fix",  # fixes fp16 instability
    "resolution": 1024,  # native SDXL training resolution
    "train_batch_size": 1,
    "gradient_accumulation_steps": 4,  # effective batch size = 4
    "learning_rate": 1e-4,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "max_train_steps": 1500,  # ~100 × num_images; recomputed in __main__ from dataset size
    "rank": 16,  # LoRA rank (reduced from 32 to fit T4 16GB)
    "snr_gamma": 5.0,  # Min-SNR weighting for stable convergence
    "mixed_precision": "fp16",  # T4 doesn't support bf16
    "checkpointing_steps": 500,
    "seed": 42,
}

# Paths — dataset/output live on Google Drive so checkpoints survive Colab
# runtime disconnects; FINAL_WEIGHTS_DIR is local to the runtime.
DATASET_DIR = "/content/drive/MyDrive/lora_training_data"
OUTPUT_DIR = "/content/drive/MyDrive/lora_output"
FINAL_WEIGHTS_DIR = "styles"


# ---------------------------------------------------------------------------
# 1. Install dependencies
# ---------------------------------------------------------------------------

def install_dependencies():
    """Install training dependencies (run once per Colab session).

    Clones the diffusers repo (for the DreamBooth training script), installs
    diffusers from source plus the example requirements, then the remaining
    packages, with peft pinned last so its version constraint wins.
    """

    def _pip(*packages):
        # Quiet pip install through the current interpreter.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-q", *packages]
        )

    # The DreamBooth LoRA training script ships inside the diffusers repo.
    if not Path("diffusers").exists():
        subprocess.check_call([
            "git", "clone", "--depth", "1",
            "https://github.com/huggingface/diffusers",
        ])

    # diffusers from source, then the DreamBooth example requirements.
    _pip("./diffusers")
    _pip("-r", "diffusers/examples/dreambooth/requirements.txt")

    # Remaining deps — peft installed last to ensure the correct version.
    _pip("transformers", "accelerate", "bitsandbytes", "safetensors", "Pillow")
    _pip("peft>=0.17.0")

    print("Dependencies installed.")


# ---------------------------------------------------------------------------
# 2. Configure accelerate
# ---------------------------------------------------------------------------

def configure_accelerate():
    """Write a default single-GPU accelerate config for this runtime.

    Delegates to accelerate's own helper so the config format matches the
    installed accelerate version. Import is deferred because accelerate is
    installed by install_dependencies() within the same session.
    """
    from accelerate.utils import write_basic_config

    write_basic_config()
    print("Accelerate configured for single GPU.")


# ---------------------------------------------------------------------------
# 3. Prepare dataset
# ---------------------------------------------------------------------------

def verify_dataset(dataset_dir: str = DATASET_DIR) -> int:
    """Verify dataset folder has images + metadata.jsonl (no .txt files).

    Args:
        dataset_dir: Path to folder on Google Drive.

    Returns:
        Number of images found.

    Raises:
        FileNotFoundError: If the folder, the images, or metadata.jsonl
            are missing.
        RuntimeError: If stray .txt files are present.
    """
    dataset_path = Path(dataset_dir)

    # Fail with a clear message instead of iterdir()'s bare OSError.
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {dataset_dir}/")

    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    images = [f for f in dataset_path.iterdir() if f.suffix.lower() in image_extensions]
    metadata = dataset_path / "metadata.jsonl"

    if not images:
        raise FileNotFoundError(f"No images found in {dataset_dir}/.")
    if not metadata.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.")

    # Hard error (not a warning) if .txt files are present — they would make
    # the dataset load as a text dataset instead of images + captions.
    txt_files = list(dataset_path.glob("*.txt"))
    if txt_files:
        raise RuntimeError(
            f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. "
            f"Remove them — only images + metadata.jsonl should be present."
        )

    print(f"Dataset OK: {len(images)} images + metadata.jsonl")
    return len(images)


# ---------------------------------------------------------------------------
# 4. Train
# ---------------------------------------------------------------------------

def train(
    dataset_dir: str = DATASET_DIR,
    output_dir: str = OUTPUT_DIR,
    resume: bool = False,
):
    """Launch DreamBooth LoRA training on SDXL.

    Args:
        dataset_dir: Path to prepared dataset.
        output_dir: Where to save checkpoints and final weights.
        resume: If True, resume from the latest checkpoint.
    """
    cfg = CONFIG

    # Render one "--name=value" CLI flag.
    opt = "--{}={}".format

    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py",
        opt("pretrained_model_name_or_path", cfg["base_model"]),
        opt("pretrained_vae_model_name_or_path", cfg["vae"]),
        opt("dataset_name", dataset_dir),
        opt("image_column", "image"),
        opt("caption_column", "prompt"),
        opt("output_dir", output_dir),
        opt("resolution", cfg["resolution"]),
        opt("train_batch_size", cfg["train_batch_size"]),
        opt("gradient_accumulation_steps", cfg["gradient_accumulation_steps"]),
        "--gradient_checkpointing",
        "--use_8bit_adam",
        opt("mixed_precision", cfg["mixed_precision"]),
        opt("learning_rate", cfg["learning_rate"]),
        opt("lr_scheduler", cfg["lr_scheduler"]),
        opt("lr_warmup_steps", cfg["lr_warmup_steps"]),
        opt("max_train_steps", cfg["max_train_steps"]),
        opt("rank", cfg["rank"]),
        opt("snr_gamma", cfg["snr_gamma"]),
        opt("instance_prompt", INSTANCE_PROMPT),
        opt("checkpointing_steps", cfg["checkpointing_steps"]),
        opt("seed", cfg["seed"]),
    ]

    if resume:
        cmd.append("--resume_from_checkpoint=latest")

    print("Starting training...")
    print(f"  Model: {cfg['base_model']}")
    print(f"  Steps: {cfg['max_train_steps']}")
    print(f"  Rank:  {cfg['rank']}")
    print(f"  LR:    {cfg['learning_rate']}")
    print(f"  Resume: {resume}")
    print()

    # Stream merged stdout/stderr line-by-line so the progress bar and any
    # errors are visible in the notebook as they happen.
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        bufsize=1, text=True,
    )
    for out_line in proc.stdout:
        print(out_line, end="", flush=True)
    proc.wait()
    if proc.returncode != 0:
        raise RuntimeError(f"Training failed with exit code {proc.returncode}")

    print(f"\nTraining complete! Weights saved to {output_dir}/")


# ---------------------------------------------------------------------------
# 5. Copy weights to styles/
# ---------------------------------------------------------------------------

def export_weights(
    output_dir: str = OUTPUT_DIR,
    styles_dir: str = FINAL_WEIGHTS_DIR,
    style_name: str = "custom-style",
):
    """Copy trained LoRA weights to the styles directory.

    Looks for final weights first, falls back to latest checkpoint.

    Args:
        output_dir: Training output directory.
        styles_dir: Destination directory (created if missing).
        style_name: Base name of the exported .safetensors file.

    Raises:
        FileNotFoundError: If neither final weights nor a checkpoint exist.
    """
    output_path = Path(output_dir)

    # Try final weights first
    src = output_path / "pytorch_lora_weights.safetensors"

    # Fall back to latest checkpoint (sorted numerically by step count)
    if not src.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            latest = checkpoints[-1]
            # Check common checkpoint weight locations
            for candidate in (
                latest / "pytorch_lora_weights.safetensors",
                latest / "unet" / "adapter_model.safetensors",
            ):
                if candidate.exists():
                    src = candidate
                    print(f"Using checkpoint: {latest.name}")
                    break

    if not src.exists():
        raise FileNotFoundError(
            f"No weights found in {output_dir}/. "
            f"Check that training completed or a checkpoint was saved."
        )

    dst_dir = Path(styles_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    dst = dst_dir / f"{style_name}.safetensors"

    # copy2 preserves timestamps/metadata (shutil imported at module level).
    shutil.copy2(src, dst)

    size_mb = dst.stat().st_size / (1024 * 1024)
    print(f"Exported weights: {dst} ({size_mb:.1f} MB)")
    print("Download this file and place it in your project's styles/ folder.")


# ---------------------------------------------------------------------------
# 6. Backup to Google Drive
# ---------------------------------------------------------------------------

def backup_to_drive(output_dir: str = OUTPUT_DIR):
    """Copy training output to Google Drive for safety.

    Mounts Drive first if it is not already mounted.

    Note: If OUTPUT_DIR already points to Drive, this is a no-op.

    Args:
        output_dir: Directory to back up.
    """
    drive_path = Path("/content/drive/MyDrive/lora_output")

    # Already training straight into Drive — nothing to copy.
    if Path(output_dir).resolve() == drive_path.resolve():
        print("Output already on Google Drive — no backup needed.")
        return

    # Colab-only import, deferred deliberately so the module can be imported
    # outside Colab.
    if not Path("/content/drive/MyDrive").exists():
        from google.colab import drive
        drive.mount("/content/drive")

    # dirs_exist_ok lets repeated backups merge into the existing tree
    # (shutil imported at module level).
    shutil.copytree(output_dir, str(drive_path), dirs_exist_ok=True)
    print(f"Backed up to {drive_path}")


# ---------------------------------------------------------------------------
# 7. Test inference
# ---------------------------------------------------------------------------

def test_inference(
    output_dir: str = OUTPUT_DIR,
    prompt: str | None = None,
):
    """Generate a test image with the trained LoRA + Hyper-SD to verify quality.

    Uses the same setup as image_generator_hf.py for accurate results.

    Args:
        output_dir: Training output directory containing the LoRA weights.
        prompt: Test prompt; defaults to a landscape prompt built from
            TRIGGER_WORD.

    Returns:
        The generated image (also saved to test_output.png).
    """
    # Deferred imports: these are installed by install_dependencies() and
    # require the GPU runtime.
    import torch
    from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
    from huggingface_hub import hf_hub_download

    if prompt is None:
        prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style"

    print("Loading model + LoRA for test inference...")

    # Same fp16-safe VAE that was used for training (CONFIG["vae"]).
    vae = AutoencoderKL.from_pretrained(
        CONFIG["vae"], torch_dtype=torch.float16,
    )

    pipe = DiffusionPipeline.from_pretrained(
        CONFIG["base_model"],
        vae=vae,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

    # Load Hyper-SD (same as image_generator_hf.py)
    hyper_path = hf_hub_download(
        "ByteDance/Hyper-SD", "Hyper-SDXL-8steps-CFG-lora.safetensors",
    )
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")

    # Load trained style LoRA (check final weights, then latest checkpoint)
    output_path = Path(output_dir)
    weights_file = output_path / "pytorch_lora_weights.safetensors"
    if not weights_file.exists():
        checkpoints = sorted(
            output_path.glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[1]),
        )
        if checkpoints:
            weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors"
    # NOTE(review): if neither the final weights nor any checkpoint exists,
    # weights_file still points at a missing path and load_lora_weights will
    # fail with an opaque error — consider raising FileNotFoundError here.
    pipe.load_lora_weights(
        str(weights_file.parent),
        weight_name=weights_file.name,
        adapter_name="style",
    )

    # 0.125 is the Hyper-SD adapter scale mirrored from image_generator_hf.py
    # — presumably tuned for the 8-step CFG LoRA; verify against that file.
    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[0.125, 1.0],
    )

    # "trailing" timestep spacing to match the few-step Hyper-SD setup.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing",
    )

    image = pipe(
        prompt=prompt,
        negative_prompt="blurry, low quality, deformed, ugly, text, watermark",
        num_inference_steps=8,  # matches the Hyper-SDXL-8steps LoRA
        guidance_scale=5.0,
        height=1344,  # portrait 1344×768 (a supported SDXL aspect ratio)
        width=768,
    ).images[0]

    image.save("test_output.png")
    print(f"Test image saved to test_output.png")
    print(f"Prompt: {prompt}")

    return image


# ---------------------------------------------------------------------------
# Main — run all steps in sequence
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print("=" * 60)
    print("SDXL LoRA Training Pipeline")
    print("=" * 60)

    # Step 1: Install dependencies (git clone + pip; once per session)
    install_dependencies()

    # Step 2: Configure accelerate for single-GPU launch
    configure_accelerate()

    # Step 3: Verify dataset, then scale training steps with dataset size.
    num_images = verify_dataset()
    # NOTE(review): max() floors the step count at 1500, so with fewer than
    # 15 images the "images × 100" message printed below overstates the
    # contribution of the dataset size.
    steps = max(1500, num_images * 100)
    CONFIG["max_train_steps"] = steps
    print(f"Adjusted training steps to {steps} ({num_images} images × 100)")

    # Step 4: Train (long-running; streams progress to the notebook)
    train()

    # Step 5: Backup (no-op when OUTPUT_DIR already lives on Drive)
    backup_to_drive()

    # Step 6: Export final weights to styles/
    export_weights(style_name="custom-style")

    # Step 7: Test inference (requires the GPU runtime)
    test_inference()

    print("\nDone! Download styles/custom-style.safetensors")