|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
try: |
|
|
!pip uninstall -qy sd15-flow-trainer[dev] |
|
|
except: |
|
|
pass |
|
|
|
|
|
!pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import gc, os |
|
|
|
|
|
# ── Pre-encode the training set once and cache the result on disk ──────────
from sd15_trainer_geo.pipeline import load_pipeline

pipe = load_pipeline(device="cuda", dtype=torch.float16)

from sd15_trainer_geo.trainer import pre_encode_hf_dataset

# Destination file for the pre-encoded dataset; read back later by
# LatentDataset(CACHE_PATH) in the training section.
CACHE_PATH = "/content/latent_cache/object_relations_schnell_512_2.pt"

# Create the cache directory up front so a missing folder cannot make the
# job fail only when it finally tries to write its output.
# NOTE(review): pre_encode_hf_dataset may already do this itself; the call is
# idempotent, so it is harmless either way.
os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)

pre_encode_hf_dataset(
    pipe,
    dataset_name="AbstractPhil/synthetic-object-relations",
    subset="schnell_512_2",
    split="train",
    image_column="image",
    prompt_column="prompt",
    output_path=CACHE_PATH,
    image_size=512,
    batch_size=16,
    max_samples=50_000,
)

# Drop the VAE and CLIP modules from the pipeline and release cached CUDA
# blocks, then report what is still allocated.
# NOTE(review): the rest of `pipe` stays alive here, and the next section
# builds a fresh pipeline with load_pipeline() anyway — consider `del pipe`
# if peak VRAM during that reload becomes a problem.
del pipe.vae, pipe.clip
gc.collect()
torch.cuda.empty_cache()
print(f"VRAM after encoding cleanup: {torch.cuda.memory_allocated()/1e9:.1f} GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Baseline samples, generated before any geo_prior training ──────────────
from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.generate import generate, save_images, show_images

pipe = load_pipeline(device="cuda", dtype=torch.float16)

# Load the pretrained UNet checkpoint from the Hub repo root.
pipe.unet.load_pretrained(
    "AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)

# Prompt sets reused by the training and evaluation sections further down.
spatial_prompts = [
    "a red cup on top of a blue book",
    "a cat sitting beside a vase of flowers",
    "a small ball inside a glass bowl on a table",
    "a pair of shoes next to an umbrella by the door",
]

novel_prompts = [
    "a guitar leaning against a piano in a dim room",
    "three candles arranged in a triangle on a wooden tray",
    "a telescope pointed at the moon through an open window",
    "a child's drawing pinned to a refrigerator with magnets",
]

_rule = "=" * 60
print(_rule)
print("BASELINE (before geo_prior training)")
print(_rule)

# Same sampler settings for both prompt sets, so the only variable between
# the two runs is the prompt list.
_sampler_args = {"shift": 2.5, "seed": 42, "num_steps": 30}

baseline_spatial = generate(pipe, spatial_prompts, **_sampler_args)
save_images(baseline_spatial, "/content/samples_baseline_spatial")

baseline_novel = generate(pipe, novel_prompts, **_sampler_args)
save_images(baseline_novel, "/content/samples_baseline_novel")

# Display only after both sets have been generated and saved.
show_images(baseline_spatial)
show_images(baseline_novel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── geo_prior training on the cached dataset ───────────────────────────────
from sd15_trainer_geo.trainer import Trainer, TrainConfig, LatentDataset

# Step count, batch size and learning-rate schedule.
# 8333 steps x batch 6 is roughly 50,000 samples, matching the
# max_samples=50_000 cap used when the cache was encoded.
_schedule = dict(
    num_steps=8333,
    batch_size=6,
    base_lr=5e-5,
    min_lr=1e-6,
    lr_scheduler="cosine",
    warmup_steps=200,
)

# shift / cfg_dropout / min_snr_gamma settings; shift matches the value
# passed to generate() when sampling.
_objective = dict(
    shift=2.5,
    cfg_dropout=0.1,
    min_snr_gamma=5.0,
)

# Weight and warmup length for the geo loss term.
_geo = dict(
    geo_loss_weight=0.01,
    geo_loss_warmup=400,
)

# Logging, periodic sampling, checkpointing and reproducibility.
_reporting = dict(
    log_every=100,
    sample_every=2000,
    save_every=2000,
    # Two prompts from each prompt set for the periodic samples.
    sample_prompts=spatial_prompts[:2] + novel_prompts[:2],
    seed=42,
    output_dir="/content/geo_prior_object_relations",
)

config = TrainConfig(**_schedule, **_objective, **_geo, **_reporting)

dataset = LatentDataset(CACHE_PATH)
trainer = Trainer(pipe, config)
trainer.fit(dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Upload the trained geo_prior to the Hugging Face Hub ───────────────────
from sd15_trainer_geo.pipeline import push_geo_to_hub

# The hyperparameters recorded in `extra` are read from `config` instead of
# being re-typed as literals, so the uploaded metadata cannot drift from the
# values the run actually used.
# NOTE(review): assumes TrainConfig exposes its constructor arguments as
# attributes (config.num_steps is already read that way in the summary
# section) — confirm for the remaining fields.
push_geo_to_hub(
    pipe,
    repo_id="AbstractPhil/sd15-geoflow-object-association",
    base_repo="sd-legacy/stable-diffusion-v1-5",
    commit_message="geo_prior v1: 1 epoch 50k object-relations schnell_512_2",
    extra={
        "dataset": "AbstractPhil/synthetic-object-relations (schnell_512_2)",
        # Not stored on the config: sample cap used at encode time and the
        # epoch count named in the commit message.
        "samples": 50000,
        "epochs": 1,
        "steps": config.num_steps,
        "shift": config.shift,
        "base_lr": config.base_lr,
        "min_snr_gamma": config.min_snr_gamma,
        "cfg_dropout": config.cfg_dropout,
        "batch_size": config.batch_size,
        "geo_loss_weight": config.geo_loss_weight,
        # Last logged loss, or a placeholder when nothing was logged.
        "loss_final": trainer.log_history[-1]["loss"] if trainer.log_history else "n/a",
    },
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Post-training samples: same sampler settings as the baseline run ───────
def _sample_and_show(title, prompts, out_dir):
    """Print a banner, generate images for *prompts*, save and display them.

    Uses the module-level ``pipe`` with shift=2.5, seed=42, num_steps=30 and
    returns whatever ``generate`` produced so callers can keep a reference.
    """
    rule = "=" * 60
    print(rule)
    print(title)
    print(rule)
    images = generate(pipe, prompts, shift=2.5, seed=42, num_steps=30)
    save_images(images, out_dir)
    show_images(images)
    return images


trained_spatial = _sample_and_show(
    "AFTER TRAINING — Spatial Prompts (in-distribution)",
    spatial_prompts,
    "/content/samples_trained_spatial",
)

trained_novel = _sample_and_show(
    "AFTER TRAINING — Novel Prompts (out-of-distribution)",
    novel_prompts,
    "/content/samples_trained_novel",
)

# Extra prompt set used only for this final evaluation pass.
hard_spatial = [
    "a book on top of a cup",
    "a lamp beneath a table",
    "a knife to the left of a fork on a plate",
    "a hat resting on a basketball",
    "a key inside a shoe next to the door",
    "a red apple behind a green bottle",
]

hard_out = _sample_and_show(
    "HARD SPATIAL (never seen, complex relations)",
    hard_spatial,
    "/content/samples_hard_spatial",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Training summary: loss trajectory and geometry statistics ──────────────
def _fmt_stat(value, spec):
    """Format *value* with *spec*, falling back to ``str(value)``.

    Values that accept the format spec render exactly as they would in an
    f-string; anything that rejects it (e.g. the 'n/a' placeholder) is
    printed as-is instead of raising.
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


print("\n" + "=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)

if trainer.log_history:
    first = trainer.log_history[0]
    last = trainer.log_history[-1]
    mid = trainer.log_history[len(trainer.log_history) // 2]

    # Fall back to the configured step count if the log entry has no 'step'.
    print(f"Steps: {last.get('step', config.num_steps)}")
    print(f"Loss (start): {first['loss']:.4f}")
    print(f"Loss (mid): {mid['loss']:.4f}")
    print(f"Loss (final): {last['loss']:.4f}")
    print(f"Task (final): {last.get('task_loss', 'n/a')}")
    print(f"Geo (final): {last.get('geo_loss', 'n/a')}")

stats = pipe.unet.get_geometry_stats()
if stats:
    print("\nGeometry:")
    print(f" Blend: {stats.get('blend', 'n/a')}")
    for i in range(4):
        vol = stats.get(f"layer_{i}/vol_sq", "n/a")
        ent = stats.get(f"layer_{i}/entropy", "n/a")
        ds = stats.get(f"layer_{i}/deform_scale", "n/a")
        # A layer is reported only when its volume stat is a float, as
        # before. entropy / deform_scale are formatted defensively: the old
        # f-string applied float specs to them unconditionally and raised
        # ValueError when either was the 'n/a' fallback.
        if isinstance(vol, float):
            print(
                f" Layer {i}: vol²={vol:.4e}, "
                f"entropy={_fmt_stat(ent, '.2f')}, δ={_fmt_stat(ds, '.4f')}"
            )

print("\nCheckpoints: /content/geo_prior_object_relations/")
print("Hub: https://huggingface.co/AbstractPhil/sd15-geoflow-object-association")