# =============================================================================
# SD15 Geo Prior Training — ImageNet-Synthetic (Schnell)
# Target: L4 (24GB VRAM)
# =============================================================================
# Cell 1: Install
# =============================================================================
# !pip install -q datasets transformers accelerate safetensors
# try:
# !pip uninstall -qy sd15-flow-trainer[dev]
# except:
# pass
#
# !pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q
# =============================================================================
# Cell 2: Pre-encode VAE + CLIP latents (cached to disk)
# =============================================================================
import torch
import os

# Cache pre-encoded VAE/CLIP latents on disk so re-runs of this cell (and the
# training cell below) can skip the expensive encoding pass entirely.
CACHE_DIR = "/content/latent_cache"
CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")
os.makedirs(CACHE_DIR, exist_ok=True)

if os.path.exists(CACHE_FILE):
    print(f"✓ Cache exists: {CACHE_FILE}")
else:
    # Import lazily: the encoder pipeline is only needed on a cache miss.
    from sd15_trainer_geo.pipeline import load_pipeline
    from sd15_trainer_geo.trainer import pre_encode_hf_dataset

    # Load pipeline with VAE + CLIP for encoding.
    pipe = load_pipeline(device="cuda", dtype=torch.float16)
    pre_encode_hf_dataset(
        pipe,
        dataset_name="AbstractPhil/imagenet-synthetic",
        subset="flux_schnell_512",
        split="train",
        image_column="image",
        prompt_column="prompt",
        output_path=CACHE_FILE,
        image_size=512,
        batch_size=16,  # L4 (24GB) handles batch 16 for encoding
    )

    # Free VAE + CLIP memory before training.
    del pipe
    torch.cuda.empty_cache()
    print("✓ Encoding complete, VRAM cleared")
# =============================================================================
# Cell 3: Load pipeline + Lune for training
# =============================================================================
from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.trainer import TrainConfig, Trainer, LatentDataset
from sd15_trainer_geo.generate import generate, show_images, save_images

# Build a fresh fp16 pipeline, then swap the Lune UNet weights into it.
pipe = load_pipeline(device="cuda", dtype=torch.float16)
pipe.unet.load_pretrained(
    repo_id="AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)

# Sanity check: confirm Lune generates coherently before any training.
print("\n--- Pre-training baseline ---")
baseline_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
pre_out = generate(
    pipe,
    baseline_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(pre_out, "/content/baseline_samples")
show_images(pre_out)
# =============================================================================
# Cell 4: Configure and train
# =============================================================================
dataset = LatentDataset(CACHE_FILE)

# ~1 epoch over the 10k cached latents at batch size 6 (10k / 6 ≈ 1667 steps).
# bs=6 is L4-safe with the frozen fp16 UNet plus the fp32 geo_prior head.
train_kwargs = dict(
    # Core
    num_steps=1667,
    batch_size=6,
    base_lr=1e-4,  # geo_prior only — higher than a full-UNet LR
    weight_decay=0.01,
    # Flow matching — match Lune's training setup
    shift=2.5,
    t_sample="logit_normal",
    logit_normal_mean=0.0,
    logit_normal_std=1.0,
    t_min=0.001,
    t_max=1.0,
    # CFG dropout — critical for inference quality
    cfg_dropout=0.1,
    # Min-SNR weighting — match Lune
    min_snr_gamma=5.0,
    # Geometric loss
    geo_loss_weight=0.01,
    geo_loss_warmup=200,
    # LR schedule
    lr_scheduler="cosine",
    warmup_steps=100,
    min_lr=1e-6,
    # Mixed precision
    use_amp=True,
    grad_clip=1.0,
    # Logging + periodic sampling/checkpointing
    log_every=50,
    sample_every=500,
    save_every=500,
    sample_prompts=[
        "a tabby cat sitting on a windowsill",
        "mountains at sunset, landscape painting",
        "a bowl of ramen, studio photography",
        "an astronaut riding a horse on mars",
    ],
    sample_steps=25,
    sample_cfg=7.5,
    # Output
    output_dir="/content/geo_train_imagenet",
    hub_repo_id=None,  # set to a repo id to push checkpoints
    # Data loading
    num_workers=2,
    pin_memory=True,
    seed=42,
)
config = TrainConfig(**train_kwargs)

trainer = Trainer(pipe, config)
trainer.fit(dataset)
# =============================================================================
# Cell 5: Compare before/after
# =============================================================================
# Regenerate the baseline prompts with the trained model for a direct
# before/after comparison (same seed/steps/cfg as the pre-training run).
print("\n--- Post-training samples ---")
post_out = generate(
    pipe,
    ["a tabby cat on a windowsill",
     "mountains at sunset, landscape painting",
     "a bowl of ramen, studio photography",
     "an astronaut riding a horse on mars"],
    num_steps=25, cfg_scale=7.5, shift=2.5, seed=42,
)
save_images(post_out, "/content/post_train_samples")
show_images(post_out)

# Also try prompts NOT in the training set to probe generalization.
print("\n--- Novel prompts (not in training set) ---")
novel_out = generate(
    pipe,
    ["a cyberpunk cityscape at night with neon lights",
     "a golden retriever playing in autumn leaves",
     "a steampunk clocktower, detailed illustration",
     "an underwater coral reef, macro photography"],
    num_steps=25, cfg_scale=7.5, shift=2.5, seed=123,
)
save_images(novel_out, "/content/novel_samples")
show_images(novel_out)

# Print training summary: first vs. last logged step.
print(f"\nTraining: {len(trainer.log_history)} logged steps")
if trainer.log_history:
    first = trainer.log_history[0]
    last = trainer.log_history[-1]
    print(f" Loss: {first['loss']:.4f} → {last['loss']:.4f}")
    print(f" Task: {first['task_loss']:.4f} → {last['task_loss']:.4f}")
    print(f" Geo: {first['geo_loss']:.6f} → {last['geo_loss']:.6f}")
    print(f" t_mean: {last.get('t_mean', 0):.3f} ± {last.get('t_std', 0):.3f}")