File size: 5,961 Bytes
121b617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# =============================================================================
# SD15 Geo Prior Training β€” ImageNet-Synthetic (Schnell)
# Target: L4 (24GB VRAM)
# =============================================================================
# Cell 1: Install
# =============================================================================
# !pip install -q datasets transformers accelerate safetensors
# try:
#   !pip uninstall -qy sd15-flow-trainer[dev]
# except:
#   pass
# 
# !pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q
# =============================================================================
# Cell 2: Pre-encode VAE + CLIP latents (cached to disk)
# =============================================================================
import torch
import os

CACHE_DIR = "/content/latent_cache"
CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")
os.makedirs(CACHE_DIR, exist_ok=True)

if not os.path.exists(CACHE_FILE):
    from sd15_trainer_geo.pipeline import load_pipeline
    from sd15_trainer_geo.trainer import pre_encode_hf_dataset

    # Bring up the full pipeline (VAE + CLIP) only for this one encoding pass.
    pipe = load_pipeline(device="cuda", dtype=torch.float16)

    pre_encode_hf_dataset(
        pipe,
        dataset_name="AbstractPhil/imagenet-synthetic",
        subset="flux_schnell_512",
        split="train",
        image_column="image",
        prompt_column="prompt",
        output_path=CACHE_FILE,
        image_size=512,
        batch_size=16,        # encoding-only batch; an L4 handles 16 comfortably
    )

    # Drop the encoder pipeline so training starts with a clean VRAM budget.
    del pipe
    torch.cuda.empty_cache()
    print("βœ“ Encoding complete, VRAM cleared")
else:
    print(f"βœ“ Cache exists: {CACHE_FILE}")

# =============================================================================
# Cell 3: Load pipeline + Lune for training
# =============================================================================
from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.trainer import TrainConfig, Trainer, LatentDataset
from sd15_trainer_geo.generate import generate, show_images, save_images

pipe = load_pipeline(device="cuda", dtype=torch.float16)
pipe.unet.load_pretrained(
    repo_id="AbstractPhil/tinyflux-experts",
    subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)

# Sanity check: Lune should already produce coherent images before any
# geo-prior training happens. Keep these samples as the "before" reference.
baseline_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
print("\n--- Pre-training baseline ---")
pre_out = generate(
    pipe,
    baseline_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(pre_out, "/content/baseline_samples")
show_images(pre_out)

# =============================================================================
# Cell 4: Configure and train
# =============================================================================
dataset = LatentDataset(CACHE_FILE)

# ~10k cached latents at batch_size=6 means 1667 optimizer steps covers roughly
# one epoch. bs=6 is L4-safe with the frozen fp16 UNet plus the fp32 geo_prior.
config = TrainConfig(
    # --- core optimization ---
    num_steps=1667,           # ~1 epoch over the cached dataset
    batch_size=6,             # L4-safe with frozen backbone
    base_lr=1e-4,             # only geo_prior trains, so LR runs hotter than a full-UNet LR
    weight_decay=0.01,
    lr_scheduler="cosine",
    warmup_steps=100,
    min_lr=1e-6,
    use_amp=True,
    grad_clip=1.0,

    # --- flow-matching schedule (kept identical to Lune's own training) ---
    shift=2.5,
    t_sample="logit_normal",
    logit_normal_mean=0.0,
    logit_normal_std=1.0,
    t_min=0.001,
    t_max=1.0,
    min_snr_gamma=5.0,

    # --- regularization ---
    cfg_dropout=0.1,          # unconditional dropout; required for CFG at inference
    geo_loss_weight=0.01,
    geo_loss_warmup=200,

    # --- logging / checkpoints / periodic samples ---
    log_every=50,
    sample_every=500,
    save_every=500,
    sample_prompts=[
        "a tabby cat sitting on a windowsill",
        "mountains at sunset, landscape painting",
        "a bowl of ramen, studio photography",
        "an astronaut riding a horse on mars",
    ],
    sample_steps=25,
    sample_cfg=7.5,

    # --- output / data loading ---
    output_dir="/content/geo_train_imagenet",
    hub_repo_id=None,         # set a repo id here to push checkpoints
    num_workers=2,
    pin_memory=True,
    seed=42,
)

trainer = Trainer(pipe, config)
trainer.fit(dataset)

# =============================================================================
# Cell 5: Compare before/after
# =============================================================================
# Regenerate the baseline prompts with the same seed so the before/after
# grids differ only by the training that just happened.
seen_prompts = [
    "a tabby cat on a windowsill",
    "mountains at sunset, landscape painting",
    "a bowl of ramen, studio photography",
    "an astronaut riding a horse on mars",
]
print("\n--- Post-training samples ---")
post_out = generate(
    pipe,
    seen_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=42,
)
save_images(post_out, "/content/post_train_samples")
show_images(post_out)

# Generalization check: prompts that never appeared in the training data.
unseen_prompts = [
    "a cyberpunk cityscape at night with neon lights",
    "a golden retriever playing in autumn leaves",
    "a steampunk clocktower, detailed illustration",
    "an underwater coral reef, macro photography",
]
print("\n--- Novel prompts (not in training set) ---")
novel_out = generate(
    pipe,
    unseen_prompts,
    num_steps=25,
    cfg_scale=7.5,
    shift=2.5,
    seed=123,
)
save_images(novel_out, "/content/novel_samples")
show_images(novel_out)

# Condensed summary from the trainer's step logs: first vs. last logged step.
print(f"\nTraining: {len(trainer.log_history)} logged steps")
if trainer.log_history:
    first, last = trainer.log_history[0], trainer.log_history[-1]
    print(f"  Loss: {first['loss']:.4f} β†’ {last['loss']:.4f}")
    print(f"  Task: {first['task_loss']:.4f} β†’ {last['task_loss']:.4f}")
    print(f"  Geo:  {first['geo_loss']:.6f} β†’ {last['geo_loss']:.6f}")
    print(f"  t_mean: {last.get('t_mean', 0):.3f} Β± {last.get('t_std', 0):.3f}")