# Author: AbstractPhil
# Commit 121b617 (verified): "Create colab_trainer.py"
# =============================================================================
# SD15 Geo Prior Training — ImageNet-Synthetic (Flux Schnell subset)
# Target: L4 (24GB VRAM)
# =============================================================================
# Cell 1: Install
# =============================================================================
# !pip install -q datasets transformers accelerate safetensors
# try:
# !pip uninstall -qy sd15-flow-trainer[dev]
# except:
# pass
#
# !pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q
# =============================================================================
# Cell 2: Pre-encode VAE + CLIP latents (cached to disk)
# =============================================================================
import torch
import os
# Location of the pre-encoded VAE + CLIP data written by pre_encode_hf_dataset
# and later read back by LatentDataset for training.
CACHE_DIR = "/content/latent_cache"
CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")
os.makedirs(CACHE_DIR, exist_ok=True)

# Encoding is done once; subsequent runs of this cell reuse the file on disk.
# NOTE(review): only the file's existence is checked. If pre_encode_hf_dataset
# can leave a partial file behind when interrupted (not visible from here),
# that file would be reused as-is; delete CACHE_FILE to force a fresh encode.
if os.path.exists(CACHE_FILE):
    print(f"✓ Cache exists: {CACHE_FILE}")
else:
    # Only needed when encoding, so imported inside this branch.
    from sd15_trainer_geo.pipeline import load_pipeline
    from sd15_trainer_geo.trainer import pre_encode_hf_dataset

    # Load pipeline with VAE + CLIP for encoding
    pipe = load_pipeline(device="cuda", dtype=torch.float16)
    pre_encode_hf_dataset(
        pipe,
        dataset_name="AbstractPhil/imagenet-synthetic",
        subset="flux_schnell_512",
        split="train",
        image_column="image",
        prompt_column="prompt",
        output_path=CACHE_FILE,
        image_size=512,
        batch_size=16,  # L4 handles 16 for encoding
    )

    # Free VAE + CLIP memory before training; the next cell loads a fresh
    # pipeline for the training run.
    del pipe
    torch.cuda.empty_cache()
    print("✓ Encoding complete, VRAM cleared")
# =============================================================================
# Cell 3: Load pipeline + Lune for training
# =============================================================================
from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.trainer import TrainConfig, Trainer, LatentDataset
from sd15_trainer_geo.generate import generate, show_images, save_images
# Fresh fp16 pipeline on the GPU; the Lune UNet weights are swapped in below.
pipe = load_pipeline(device="cuda", dtype=torch.float16)

lune_checkpoint = dict(
    repo_id="AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)
pipe.unet.load_pretrained(**lune_checkpoint)

# Sample once before any training so there is a reference set of images
# (saved to disk and displayed) to judge the fine-tuned result against.
print("\n--- Pre-training baseline ---")
baseline_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
pre_out = generate(
    pipe,
    baseline_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(pre_out, "/content/baseline_samples")
show_images(pre_out)
# =============================================================================
# Cell 4: Configure and train
# =============================================================================
dataset = LatentDataset(CACHE_FILE)

# Step budget: the cache holds ~10k samples and ceil(10000 / 6) = 1667, so
# num_steps is roughly one epoch at batch size 6. Per the original notes,
# batch 6 fits on an L4 with the UNet frozen in fp16 and the geo prior in fp32.
optim_settings = dict(
    num_steps=1667,
    batch_size=6,
    base_lr=1e-4,  # geo_prior only, so higher than a full-UNet learning rate
    weight_decay=0.01,
)

# Flow-matching timestep sampling, chosen to match Lune.
flow_settings = dict(
    shift=2.5,
    t_sample="logit_normal",
    logit_normal_mean=0.0,
    logit_normal_std=1.0,
    t_min=0.001,
    t_max=1.0,
)

loss_settings = dict(
    cfg_dropout=0.1,      # CFG dropout; noted as critical for inference quality
    min_snr_gamma=5.0,    # Min-SNR weighting, matching Lune
    geo_loss_weight=0.01,
    geo_loss_warmup=200,
)

# Cosine LR decay with warmup, AMP, and gradient clipping.
schedule_settings = dict(
    lr_scheduler="cosine",
    warmup_steps=100,
    min_lr=1e-6,
    use_amp=True,
    grad_clip=1.0,
)

# How often to log, render the prompts below, and checkpoint.
monitor_settings = dict(
    log_every=50,
    sample_every=500,
    save_every=500,
    sample_prompts=[
        "a tabby cat sitting on a windowsill",
        "mountains at sunset, landscape painting",
        "a bowl of ramen, studio photography",
        "an astronaut riding a horse on mars",
    ],
    sample_steps=25,
    sample_cfg=7.5,
)

# Output location (hub_repo_id=None: set a repo id to push checkpoints)
# plus data-loader and seeding options.
io_settings = dict(
    output_dir="/content/geo_train_imagenet",
    hub_repo_id=None,
    num_workers=2,
    pin_memory=True,
    seed=42,
)

config = TrainConfig(
    **optim_settings,
    **flow_settings,
    **loss_settings,
    **schedule_settings,
    **monitor_settings,
    **io_settings,
)
trainer = Trainer(pipe, config)
trainer.fit(dataset)
# =============================================================================
# Cell 5: Compare before/after
# =============================================================================
print("\n--- Post-training samples ---")
# Same four prompts and seed (42) as the pre-training baseline, so each image
# can be set side by side with its baseline counterpart.
post_out = generate(
    pipe,
    ["a tabby cat on a windowsill",
     "mountains at sunset, landscape painting",
     "a bowl of ramen, studio photography",
     "an astronaut riding a horse on mars"],
    num_steps=25, cfg_scale=7.5, shift=2.5, seed=42,
)
save_images(post_out, "/content/post_train_samples")
show_images(post_out)

# Also try prompts outside the comparison set above
# (presumably also absent from the training data; not verified here).
print("\n--- Novel prompts (not in training set) ---")
novel_out = generate(
    pipe,
    ["a cyberpunk cityscape at night with neon lights",
     "a golden retriever playing in autumn leaves",
     "a steampunk clocktower, detailed illustration",
     "an underwater coral reef, macro photography"],
    num_steps=25, cfg_scale=7.5, shift=2.5, seed=123,
)
save_images(novel_out, "/content/novel_samples")
show_images(novel_out)

# Print training summary: first vs. last logged entry of trainer.log_history.
print(f"\nTraining: {len(trainer.log_history)} logged steps")
if trainer.log_history:
    first = trainer.log_history[0]
    last = trainer.log_history[-1]
    print(f" Loss: {first['loss']:.4f} → {last['loss']:.4f}")
    print(f" Task: {first['task_loss']:.4f} → {last['task_loss']:.4f}")
    print(f" Geo: {first['geo_loss']:.6f} → {last['geo_loss']:.6f}")
    # t_mean / t_std may be missing from a log entry; fall back to 0 for display.
    print(f" t_mean: {last.get('t_mean', 0):.3f} ± {last.get('t_std', 0):.3f}")