AbstractPhil committed on
Commit
121b617
·
verified ·
1 Parent(s): ea637f7

Create colab_trainer.py

Browse files
Files changed (1) hide show
  1. colab_trainer.py +180 -0
colab_trainer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # SD15 Geo Prior Training — ImageNet-Synthetic (Schnell)
3
+ # Target: L4 (24GB VRAM)
4
+ # =============================================================================
5
+ # Cell 1: Install
6
+ # =============================================================================
7
+ # !pip install -q datasets transformers accelerate safetensors
8
+ # try:
9
+ # !pip uninstall -qy sd15-flow-trainer[dev]
10
+ # except:
11
+ # pass
12
+ #
13
+ # !pip install "sd15-flow-trainer[dev] @ git+https://github.com/AbstractEyes/sd15-flow-trainer.git" -q
14
+ # =============================================================================
15
+ # Cell 2: Pre-encode VAE + CLIP latents (cached to disk)
16
+ # =============================================================================
17
+ import torch
18
+ import os
19
+
20
+ CACHE_DIR = "/content/latent_cache"
21
+ CACHE_FILE = os.path.join(CACHE_DIR, "imagenet_synthetic_flux_10k.pt")
22
+ os.makedirs(CACHE_DIR, exist_ok=True)
23
+
24
+ if os.path.exists(CACHE_FILE):
25
+ print(f"✓ Cache exists: {CACHE_FILE}")
26
+ else:
27
+ from sd15_trainer_geo.pipeline import load_pipeline
28
+ from sd15_trainer_geo.trainer import pre_encode_hf_dataset
29
+
30
+ # Load pipeline with VAE + CLIP for encoding
31
+ pipe = load_pipeline(device="cuda", dtype=torch.float16)
32
+
33
+ pre_encode_hf_dataset(
34
+ pipe,
35
+ dataset_name="AbstractPhil/imagenet-synthetic",
36
+ subset="flux_schnell_512",
37
+ split="train",
38
+ image_column="image",
39
+ prompt_column="prompt",
40
+ output_path=CACHE_FILE,
41
+ image_size=512,
42
+ batch_size=16, # L4 handles 16 for encoding
43
+ )
44
+
45
+ # Free VAE + CLIP memory before training
46
+ del pipe
47
+ torch.cuda.empty_cache()
48
+ print("✓ Encoding complete, VRAM cleared")
49
+
50
+ # =============================================================================
51
+ # Cell 3: Load pipeline + Lune for training
52
+ # =============================================================================
53
+ from sd15_trainer_geo.pipeline import load_pipeline
54
+ from sd15_trainer_geo.trainer import TrainConfig, Trainer, LatentDataset
55
+ from sd15_trainer_geo.generate import generate, show_images, save_images
56
+
57
+ pipe = load_pipeline(device="cuda", dtype=torch.float16)
58
+ pipe.unet.load_pretrained(
59
+ repo_id="AbstractPhil/tinyflux-experts",
60
+ subfolder="",
61
+ filename="sd15-flow-lune-unet.safetensors",
62
+ )
63
+
64
+ # Verify Lune generates coherently before training
65
+ print("\n--- Pre-training baseline ---")
66
+ pre_out = generate(
67
+ pipe,
68
+ ["a tabby cat on a windowsill",
69
+ "mountains at sunset, landscape painting",
70
+ "a bowl of ramen, studio photography",
71
+ "an astronaut riding a horse on mars"],
72
+ num_steps=25, cfg_scale=7.5, shift=2.5, seed=42,
73
+ )
74
+ save_images(pre_out, "/content/baseline_samples")
75
+ show_images(pre_out)
76
+
77
+ # =============================================================================
78
+ # Cell 4: Configure and train
79
+ # =============================================================================
80
+ dataset = LatentDataset(CACHE_FILE)
81
+
82
+ # 10k images / bs=6 = 1667 steps per epoch
83
+ # L4: bs=6 fits comfortably with frozen UNet fp16 + geo_prior fp32
84
+ config = TrainConfig(
85
+ # Core
86
+ num_steps=1667, # ~1 epoch
87
+ batch_size=6, # L4-safe with frozen backbone
88
+ base_lr=1e-4, # geo_prior only — higher than full UNet LR
89
+ weight_decay=0.01,
90
+
91
+ # Flow matching — match Lune
92
+ shift=2.5,
93
+ t_sample="logit_normal",
94
+ logit_normal_mean=0.0,
95
+ logit_normal_std=1.0,
96
+ t_min=0.001,
97
+ t_max=1.0,
98
+
99
+ # CFG dropout — critical for inference quality
100
+ cfg_dropout=0.1,
101
+
102
+ # Min-SNR — match Lune
103
+ min_snr_gamma=5.0,
104
+
105
+ # Geometric loss
106
+ geo_loss_weight=0.01,
107
+ geo_loss_warmup=200,
108
+
109
+ # LR schedule
110
+ lr_scheduler="cosine",
111
+ warmup_steps=100,
112
+ min_lr=1e-6,
113
+
114
+ # Mixed precision
115
+ use_amp=True,
116
+ grad_clip=1.0,
117
+
118
+ # Logging + sampling
119
+ log_every=50,
120
+ sample_every=500,
121
+ save_every=500,
122
+ sample_prompts=[
123
+ "a tabby cat sitting on a windowsill",
124
+ "mountains at sunset, landscape painting",
125
+ "a bowl of ramen, studio photography",
126
+ "an astronaut riding a horse on mars",
127
+ ],
128
+ sample_steps=25,
129
+ sample_cfg=7.5,
130
+
131
+ # Output
132
+ output_dir="/content/geo_train_imagenet",
133
+ hub_repo_id=None, # Set to push checkpoints
134
+
135
+ # Data
136
+ num_workers=2,
137
+ pin_memory=True,
138
+ seed=42,
139
+ )
140
+
141
+ trainer = Trainer(pipe, config)
142
+ trainer.fit(dataset)
143
+
144
+ # =============================================================================
145
+ # Cell 5: Compare before/after
146
+ # =============================================================================
147
+ print("\n--- Post-training samples ---")
148
+ post_out = generate(
149
+ pipe,
150
+ ["a tabby cat on a windowsill",
151
+ "mountains at sunset, landscape painting",
152
+ "a bowl of ramen, studio photography",
153
+ "an astronaut riding a horse on mars"],
154
+ num_steps=25, cfg_scale=7.5, shift=2.5, seed=42,
155
+ )
156
+ save_images(post_out, "/content/post_train_samples")
157
+ show_images(post_out)
158
+
159
+ # Also try prompts NOT in training set
160
+ print("\n--- Novel prompts (not in training set) ---")
161
+ novel_out = generate(
162
+ pipe,
163
+ ["a cyberpunk cityscape at night with neon lights",
164
+ "a golden retriever playing in autumn leaves",
165
+ "a steampunk clocktower, detailed illustration",
166
+ "an underwater coral reef, macro photography"],
167
+ num_steps=25, cfg_scale=7.5, shift=2.5, seed=123,
168
+ )
169
+ save_images(novel_out, "/content/novel_samples")
170
+ show_images(novel_out)
171
+
172
+ # Print training summary
173
+ print(f"\nTraining: {len(trainer.log_history)} logged steps")
174
+ if trainer.log_history:
175
+ first = trainer.log_history[0]
176
+ last = trainer.log_history[-1]
177
+ print(f" Loss: {first['loss']:.4f} → {last['loss']:.4f}")
178
+ print(f" Task: {first['task_loss']:.4f} → {last['task_loss']:.4f}")
179
+ print(f" Geo: {first['geo_loss']:.6f} → {last['geo_loss']:.6f}")
180
+ print(f" t_mean: {last.get('t_mean', 0):.3f} ± {last.get('t_std', 0):.3f}")