# Hugging Face page residue (not code) — kept for provenance:
#   AbstractPhil's picture
#   Update trainer.py
#   bbc3095 verified
# =============================================================================
# Cell 1: Install dependencies
# =============================================================================
# NOTE(review): the lines below use IPython "!" shell magics and a bare
# try/except around a magic, which is notebook pseudo-code, not valid plain
# Python. They are deliberately wrapped in a string literal (this becomes the
# module docstring) so this file stays importable; paste the contents into a
# notebook cell to actually run the install.
"""
try:
!pip uninstall -qy sd15-flow-trainer[dev]
except:
pass
!pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q
"""
# =============================================================================
# Cell 2: Pre-encode 50k object-relations dataset
# =============================================================================
import gc
import os
import torch

from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.trainer import pre_encode_hf_dataset

# Latent cache is written once here and consumed by the training cell below.
CACHE_PATH = "/content/latent_cache/object_relations_schnell_512_2.pt"

pipe = load_pipeline(device="cuda", dtype=torch.float16)
pre_encode_hf_dataset(
    pipe,
    dataset_name="AbstractPhil/synthetic-object-relations",
    subset="schnell_512_2",
    split="train",
    image_column="image",
    prompt_column="prompt",
    output_path=CACHE_PATH,
    image_size=512,
    batch_size=16,
    max_samples=50_000,
)

# The VAE and CLIP encoders are only needed for encoding; drop them and
# reclaim their VRAM before training.
del pipe.vae
del pipe.clip
gc.collect()
torch.cuda.empty_cache()
print(f"VRAM after encoding cleanup: {torch.cuda.memory_allocated()/1e9:.1f} GB")
# =============================================================================
# Cell 3: Load pipeline + Lune UNet, baseline samples
# =============================================================================
from sd15_trainer_geo.generate import generate, save_images, show_images
from sd15_trainer_geo.pipeline import load_pipeline

pipe = load_pipeline(device="cuda", dtype=torch.float16)
pipe.unet.load_pretrained(
    "AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)

# In-distribution spatial-relation prompts (match the training data's style).
spatial_prompts = [
    "a red cup on top of a blue book",
    "a cat sitting beside a vase of flowers",
    "a small ball inside a glass bowl on a table",
    "a pair of shoes next to an umbrella by the door",
]
# Held-out prompts to probe generalization beyond the training distribution.
novel_prompts = [
    "a guitar leaning against a piano in a dim room",
    "three candles arranged in a triangle on a wooden tray",
    "a telescope pointed at the moon through an open window",
    "a child's drawing pinned to a refrigerator with magnets",
]

print("=" * 60)
print("BASELINE (before geo_prior training)")
print("=" * 60)


def _sample_and_save(prompts, out_dir):
    # Fixed seed/steps so baseline and post-training grids are comparable.
    images = generate(pipe, prompts, shift=2.5, seed=42, num_steps=30)
    save_images(images, out_dir)
    return images


baseline_spatial = _sample_and_save(spatial_prompts, "/content/samples_baseline_spatial")
baseline_novel = _sample_and_save(novel_prompts, "/content/samples_baseline_novel")
show_images(baseline_spatial)
show_images(baseline_novel)
# =============================================================================
# Cell 4: Train geo_prior on 50k object-relations
# =============================================================================
from sd15_trainer_geo.trainer import LatentDataset, TrainConfig, Trainer

# ~1 epoch: 50_000 samples / batch_size 6 ≈ 8333 optimizer steps.
config = TrainConfig(
    # Optimization schedule
    num_steps=8333,
    batch_size=6,
    base_lr=5e-5,
    min_lr=1e-6,
    lr_scheduler="cosine",
    warmup_steps=200,
    seed=42,
    # Flow matching
    shift=2.5,
    cfg_dropout=0.1,
    min_snr_gamma=5.0,
    # Geometric regularization
    geo_loss_weight=0.01,
    geo_loss_warmup=400,
    # Logging / checkpoints
    log_every=100,
    sample_every=2000,
    save_every=2000,
    sample_prompts=spatial_prompts[:2] + novel_prompts[:2],
    output_dir="/content/geo_prior_object_relations",
)

trainer = Trainer(pipe, config)
dataset = LatentDataset(CACHE_PATH)
trainer.fit(dataset)
# =============================================================================
# Cell 5: Push trained weights to hub
# =============================================================================
from sd15_trainer_geo.pipeline import push_geo_to_hub

# Final logged loss, or 'n/a' when no logging happened.
final_loss = trainer.log_history[-1]["loss"] if trainer.log_history else "n/a"

# Run metadata recorded alongside the weights on the hub.
run_card = {
    "dataset": "AbstractPhil/synthetic-object-relations (schnell_512_2)",
    "samples": 50000,
    "epochs": 1,
    "steps": 8333,
    "shift": 2.5,
    "base_lr": 5e-5,
    "min_snr_gamma": 5.0,
    "cfg_dropout": 0.1,
    "batch_size": 6,
    "geo_loss_weight": 0.01,
    "loss_final": final_loss,
}

push_geo_to_hub(
    pipe,
    repo_id="AbstractPhil/sd15-geoflow-object-association",
    base_repo="sd-legacy/stable-diffusion-v1-5",
    commit_message="geo_prior v1: 1 epoch 50k object-relations schnell_512_2",
    extra=run_card,
)
# =============================================================================
# Cell 6: Compare before/after
# =============================================================================
# Harder held-out relations: unusual supports, left/right ordering, nesting.
hard_spatial = [
    "a book on top of a cup",
    "a lamp beneath a table",
    "a knife to the left of a fork on a plate",
    "a hat resting on a basketball",
    "a key inside a shoe next to the door",
    "a red apple behind a green bottle",
]

# (banner title, prompt list, output directory) — evaluated in this order.
_eval_sets = [
    ("AFTER TRAINING — Spatial Prompts (in-distribution)",
     spatial_prompts, "/content/samples_trained_spatial"),
    ("AFTER TRAINING — Novel Prompts (out-of-distribution)",
     novel_prompts, "/content/samples_trained_novel"),
    ("HARD SPATIAL (never seen, complex relations)",
     hard_spatial, "/content/samples_hard_spatial"),
]

for title, prompts, out_dir in _eval_sets:
    print("=" * 60)
    print(title)
    print("=" * 60)
    # Same shift/seed/steps as the baseline cell for a like-for-like diff.
    images = generate(pipe, prompts, shift=2.5, seed=42, num_steps=30)
    save_images(images, out_dir)
    show_images(images)
# =============================================================================
# Cell 7: Training summary
# =============================================================================
# Prints loss progression from the trainer's log history, the UNet's geometry
# stats, and where the artifacts ended up. (Indentation of the conditional
# bodies below was reconstructed — it had been stripped from the source.)
print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)

if trainer.log_history:
    first = trainer.log_history[0]
    last = trainer.log_history[-1]
    mid = trainer.log_history[len(trainer.log_history) // 2]
    print(f"Steps: {last.get('step', config.num_steps)}")
    print(f"Loss (start): {first['loss']:.4f}")
    print(f"Loss (mid): {mid['loss']:.4f}")
    print(f"Loss (final): {last['loss']:.4f}")
    print(f"Task (final): {last.get('task_loss', 'n/a')}")
    print(f"Geo (final): {last.get('geo_loss', 'n/a')}")

stats = pipe.unet.get_geometry_stats()
if stats:
    print("\nGeometry:")
    print(f" Blend: {stats.get('blend', 'n/a')}")
    for i in range(4):
        vol = stats.get(f'layer_{i}/vol_sq', 'n/a')
        ent = stats.get(f'layer_{i}/entropy', 'n/a')
        ds = stats.get(f'layer_{i}/deform_scale', 'n/a')
        # BUGFIX: the original guarded only on `vol` being a float but then
        # formatted `ent` and `ds` with :.2f/:.4f too — a ValueError if either
        # is the 'n/a' fallback string. Require all three to be numeric.
        if all(isinstance(v, float) for v in (vol, ent, ds)):
            print(f" Layer {i}: vol²={vol:.4e}, entropy={ent:.2f}, δ={ds:.4f}")

print("\nCheckpoints: /content/geo_prior_object_relations/")
print("Hub: https://huggingface.co/AbstractPhil/sd15-geoflow-object-association")