|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch

import os

# Directory holding pre-encoded latent caches (local /content path, e.g. Colab).
CACHE_DIR = "/content/latent_cache"

# Cache file for the ImageNet-synthetic (flux_schnell_512) latents used below.
CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")

# Make sure the cache directory exists before any read/write below.
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
|
if os.path.exists(CACHE_FILE):
    # Latents were already encoded on a previous run -- skip the expensive
    # encoding pass entirely.
    print(f"✓ Cache exists: {CACHE_FILE}")
else:
    # Deferred imports: only pay for loading the encoding pipeline when we
    # actually need to build the cache.
    from sd15_trainer_geo.pipeline import load_pipeline
    from sd15_trainer_geo.trainer import pre_encode_hf_dataset

    pipe = load_pipeline(device="cuda", dtype=torch.float16)

    # Encode the HF dataset's images/prompts into latents and save to disk.
    pre_encode_hf_dataset(
        pipe,
        dataset_name="AbstractPhil/imagenet-synthetic",
        subset="flux_schnell_512",
        split="train",
        image_column="image",
        prompt_column="prompt",
        output_path=CACHE_FILE,
        image_size=512,
        batch_size=16,
    )

    # Drop the encoding pipeline so training below gets the full GPU.
    del pipe
    torch.cuda.empty_cache()
    print("✓ Encoding complete, VRAM cleared")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sd15_trainer_geo.pipeline import load_pipeline

from sd15_trainer_geo.trainer import TrainConfig, Trainer, LatentDataset

from sd15_trainer_geo.generate import generate, show_images, save_images

# Load the pipeline in fp16 on GPU, then warm-start the UNet from a hub
# checkpoint (repo root, since subfolder is empty).
pipe = load_pipeline(device="cuda", dtype=torch.float16)

pipe.unet.load_pretrained(
    repo_id="AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)
|
|
|
|
|
|
|
|
print("\n--- Pre-training baseline ---")

# Fixed prompt set and seed so post-training output is directly comparable.
baseline_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
pre_out = generate(
    pipe,
    baseline_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)

# Persist and display the pre-training reference images.
save_images(pre_out, "/content/baseline_samples")
show_images(pre_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Dataset backed by the pre-encoded latent cache built earlier in this script.
dataset = LatentDataset(CACHE_FILE)
|
|
|
|
|
|
|
|
|
|
|
# Training configuration. 1667 steps x batch 6 ~ 10k samples, matching the
# "10k" cache file -- presumably roughly one pass over the data (TODO confirm).
config = TrainConfig(
    # --- Optimization ---
    num_steps=1667,
    batch_size=6,
    base_lr=1e-4,
    weight_decay=0.01,
    # --- Timestep sampling (shifted logit-normal over [t_min, t_max]) ---
    shift=2.5,
    t_sample="logit_normal",
    logit_normal_mean=0.0,
    logit_normal_std=1.0,
    t_min=0.001,
    t_max=1.0,
    # --- Classifier-free guidance: fraction of prompts dropped in training ---
    cfg_dropout=0.1,
    # --- Min-SNR loss weighting cap ---
    min_snr_gamma=5.0,
    # --- Geometric auxiliary loss (ramped in over the warmup steps) ---
    geo_loss_weight=0.01,
    geo_loss_warmup=200,
    # --- LR schedule ---
    lr_scheduler="cosine",
    warmup_steps=100,
    min_lr=1e-6,
    # --- Mixed precision / gradient stability ---
    use_amp=True,
    grad_clip=1.0,
    # --- Logging, sampling, and checkpoint cadence (in steps) ---
    log_every=50,
    sample_every=500,
    save_every=500,
    sample_prompts=[
        "a tabby cat sitting on a windowsill",
        "mountains at sunset, landscape painting",
        "a bowl of ramen, studio photography",
        "an astronaut riding a horse on mars",
    ],
    sample_steps=25,
    sample_cfg=7.5,
    # --- Output: local only; hub_repo_id=None disables hub uploads ---
    output_dir="/content/geo_train_imagenet",
    hub_repo_id=None,
    # --- DataLoader settings ---
    num_workers=2,
    pin_memory=True,
    seed=42,
)
|
|
|
|
|
# Run training on the cached latents; cadence/behavior comes from `config`.
trainer = Trainer(pipe, config)

trainer.fit(dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n--- Post-training samples ---")

# Same prompts and seed as the pre-training baseline, so any difference in
# the images comes from training rather than sampling noise.
comparison_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
post_out = generate(
    pipe,
    comparison_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)

save_images(post_out, "/content/post_train_samples")
show_images(post_out)
|
|
|
|
|
|
|
|
print("\n--- Novel prompts (not in training set) ---")

# Held-out prompts (different seed too) to spot-check generalization beyond
# the prompts used for training-time sampling.
unseen_prompts = [
    "a cyberpunk cityscape at night with neon lights",
    "a golden retriever playing in autumn leaves",
    "a steampunk clocktower, detailed illustration",
    "an underwater coral reef, macro photography",
]
novel_out = generate(
    pipe,
    unseen_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=123,
)

save_images(novel_out, "/content/novel_samples")
show_images(novel_out)
|
|
|
|
|
|
|
|
# Print a compact first-step -> last-step summary of the logged metrics.
print(f"\nTraining: {len(trainer.log_history)} logged steps")

if trainer.log_history:
    first = trainer.log_history[0]
    last = trainer.log_history[-1]
    # First vs. last logged values show the overall trend across the run.
    print(f" Loss: {first['loss']:.4f} → {last['loss']:.4f}")
    print(f" Task: {first['task_loss']:.4f} → {last['task_loss']:.4f}")
    print(f" Geo: {first['geo_loss']:.6f} → {last['geo_loss']:.6f}")
    # t_mean/t_std may be absent from a log entry, hence the .get defaults.
    print(f" t_mean: {last.get('t_mean', 0):.3f} ± {last.get('t_std', 0):.3f}")