AbstractPhil committed on
Commit
a978243
·
verified ·
1 Parent(s): af9b8e5

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +348 -0
trainer.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Burn Test: ~44 images → ~10k via multiplication, AR bucketed
3
+ # =============================================================================
4
+ # Prerequisites:
5
+ # !pip install -q torch torchvision safetensors transformers pillow
6
+ # !cd /content && git clone https://github.com/AbstractPhil/sd15-trainer-geo.git
7
+ # !cd /content/sd15-trainer-geo && pip install -e .
8
+ # Place burn_images_test.zip in /content/
9
+
10
+ import torch, gc, os, json, glob, zipfile, math, random, time
11
+ import numpy as np
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from collections import defaultdict
15
+ from torchvision import transforms
16
+
17
# =============================================================================
# 1 — Unzip + discover images and tags
# =============================================================================

ZIP_PATH = "/content/burn_images_test.zip"   # archive containing the source images
EXTRACT = "/content/burn_images"             # where the archive is unpacked
CACHE_DIR = "/content/latent_cache_burn"     # where per-bucket latent caches are written
TARGET = 10_000                              # desired sample count after multiplication

os.makedirs(EXTRACT, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

with zipfile.ZipFile(ZIP_PATH, "r") as z:
    z.extractall(EXTRACT)

IMG_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}

# Collect every image file below EXTRACT in a deterministic (sorted) order.
# Directories whose name happens to end in an image suffix are skipped, as are
# the AppleDouble entries ("__MACOSX/", "._name.png") that macOS adds to zip
# archives: they carry an image extension but are not decodable images.
image_paths = sorted(
    p for p in Path(EXTRACT).rglob("*")
    if p.is_file()
    and p.suffix.lower() in IMG_EXTS
    and "__MACOSX" not in p.parts
    and not p.name.startswith("._")
)
print(f"Found {len(image_paths)} images")

# Fail here with a clear message rather than with a ZeroDivisionError later,
# when the repeat count is derived from the number of images.
if not image_paths:
    raise FileNotFoundError(
        f"No images with extensions {sorted(IMG_EXTS)} found under {EXTRACT}"
    )
38
+
39
def find_tags(img_path: Path) -> str:
    """Return the caption/tag string associated with an image.

    Lookup order:
      1. A sidecar file next to the image with the same stem and one of the
         suffixes ``.txt``, ``.caption``, ``.tags`` (first match wins).
      2. ``captions/<stem>.txt`` inside the image's directory.
      3. Fallback: the image's file stem with ``_`` and ``-`` turned into spaces.

    Caption files are read as UTF-8 explicitly so the result does not depend
    on the platform's locale encoding. Surrounding whitespace is stripped.
    """
    for ext in (".txt", ".caption", ".tags"):
        sidecar = img_path.with_suffix(ext)
        if sidecar.exists():
            return sidecar.read_text(encoding="utf-8").strip()

    caption_file = img_path.parent / "captions" / (img_path.stem + ".txt")
    if caption_file.exists():
        return caption_file.read_text(encoding="utf-8").strip()

    # No caption found anywhere: derive a prompt from the file name.
    return img_path.stem.replace("_", " ").replace("-", " ")
48
+
49
# Load every discovered image into memory together with its size and tags.
samples = []
for p in image_paths:
    # Image.open() is lazy and keeps the file open; converting inside the
    # context manager decodes the pixels into a new image and then releases
    # the underlying file handle deterministically.
    with Image.open(p) as src:
        img = src.convert("RGB")
    w, h = img.size
    samples.append({"path": p, "image": img, "w": w, "h": h, "tags": find_tags(p)})

print(f"\n── Image Inventory ({len(samples)} images) ──")
for s in samples:
    # One row per image: file name, pixel size, aspect ratio, first 60 tag chars.
    print(f" {s['path'].name:40s} {s['w']:4d}×{s['h']:4d} AR={s['w']/s['h']:.2f} {s['tags'][:60]}")
58
+
59
# =============================================================================
# 2 — AR bucketing
# =============================================================================

# Candidate training resolutions as (width, height) pairs.
# Each bucket holds roughly 262k pixels (246k–270k) and every side is a
# multiple of 8, so it divides evenly by the VAE's 8x downsampling.
BUCKETS = [
    (512, 512),   # square, 1:1
    (576, 448),   # mild landscape
    (448, 576),   # mild portrait
    (640, 384),   # wide landscape
    (384, 640),   # tall portrait
    (704, 384),   # very wide landscape
    (384, 704),   # very tall portrait
]
73
+
74
def nearest_bucket(w, h, buckets=None):
    """Return the (width, height) bucket whose aspect ratio is closest to w/h.

    Args:
        w: Source image width in pixels.
        h: Source image height in pixels (must be non-zero).
        buckets: Optional iterable of (width, height) pairs to choose from.
            Defaults to the module-level ``BUCKETS``.

    Returns:
        The chosen bucket as a ``(width, height)`` tuple. On a tie the bucket
        that appears first wins.

    Distance is the absolute difference of the aspect ratios. There is no
    fixed upper bound on that distance, so extreme aspect ratios still map to
    the genuinely closest bucket rather than falling back to the first one.
    """
    candidates = BUCKETS if buckets is None else buckets
    ar = w / h
    # min() keeps the first of several equally close candidates, matching a
    # manual scan that only replaces the best on a strictly smaller distance.
    best = min(candidates, key=lambda wh: abs(ar - wh[0] / wh[1]))
    return tuple(best)
82
+
83
# Assign every sample to its nearest aspect-ratio bucket and group by bucket.
bucket_groups = defaultdict(list)
for s in samples:
    s["bucket"] = nearest_bucket(s["w"], s["h"])
    bucket_groups[s["bucket"]].append(s)

print("\n── Bucket Assignment ──")
for (bw, bh), items in sorted(bucket_groups.items()):
    print(f" {bw}×{bh} ({bw/bh:.2f}): {len(items)} images")
91
+
92
# =============================================================================
# 3 — Encode latents per bucket (with multiplication)
# =============================================================================

from sd15_trainer_geo.pipeline import load_pipeline

pipe = load_pipeline(device="cuda", dtype=torch.float16)

n_images = len(samples)
repeats = max(1, TARGET // n_images)      # how many times each image is duplicated
actual_total = n_images * repeats
print(f"\n── Multiplication: {n_images} × {repeats} = {actual_total} ──")


def _fit_to_bucket(img, bw, bh):
    """Scale a PIL image so it just covers bw×bh, then center-crop to bw×bh.

    The scale factor is chosen per image so that one side lands exactly on
    the bucket size and the other is at least as large; only the overhang on
    that second side is cropped. (Resizing the short edge to max(bw, bh)
    would upscale too far for non-square buckets and crop away content on
    both axes.)
    """
    scale = max(bw / img.width, bh / img.height)
    # max(...) guards against float rounding leaving a side one pixel short.
    new_w = max(bw, round(img.width * scale))
    new_h = max(bh, round(img.height * scale))
    resized = transforms.Resize(
        (new_h, new_w), interpolation=transforms.InterpolationMode.LANCZOS
    )(img)
    return transforms.CenterCrop((bh, bw))(resized)


# Bucket-independent part of the preprocessing: PIL -> tensor in [-1, 1].
_to_model_input = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

bucket_caches = {}

for (bw, bh), items in sorted(bucket_groups.items()):
    n_bucket = len(items) * repeats
    print(f"\n Encoding {bw}×{bh}: {len(items)} unique → {n_bucket} total")

    all_latents, all_enc_hs = [], []

    for s in items:
        img_t = _to_model_input(_fit_to_bucket(s["image"], bw, bh))
        img_t = img_t.unsqueeze(0).to(pipe.device, pipe.dtype)
        with torch.no_grad():
            lat = pipe.encode_image(img_t, sample=True)
            ehs = pipe.encode_prompts([s["tags"]])
        # Each image is encoded once; that single latent / text encoding is
        # then duplicated `repeats` times to reach the target sample count.
        all_latents.extend([lat.cpu()] * repeats)
        all_enc_hs.extend([ehs.cpu()] * repeats)

    latents = torch.cat(all_latents, dim=0)
    enc_hs = torch.cat(all_enc_hs, dim=0)

    cache_path = os.path.join(CACHE_DIR, f"burn_{bw}x{bh}.pt")
    torch.save({"latents": latents, "encoder_hidden_states": enc_hs}, cache_path)
    bucket_caches[(bw, bh)] = {"path": cache_path, "count": len(latents)}
    print(f" ✓ {len(latents)} → {cache_path} (latent {latents.shape})")

# Free encoder models
del pipe
gc.collect()
torch.cuda.empty_cache()
140
+
141
# =============================================================================
# 4 — Reload pipeline + Lune UNet
# =============================================================================

from sd15_trainer_geo.pipeline import load_pipeline
from sd15_trainer_geo.generate import generate, save_images, show_images

# Fresh pipeline for training; the UNet weights are replaced with the
# checkpoint file named below.
pipe = load_pipeline(device="cuda", dtype=torch.float16)
pipe.unet.load_pretrained(
    "AbstractPhil/tinyflux-experts", subfolder="",
    filename="sd15-flow-lune-unet.safetensors",
)

# The tags of the first (up to) four source images double as sample prompts.
sample_tags = [s["tags"] for s in samples[:4]]
print("\n── Sample prompts ──")
for t in sample_tags:
    print(f" {t[:80]}")

print("\n" + "=" * 60)
print("BASELINE (before training)")
print("=" * 60)
# Generated with the same settings as the post-training samples so the two
# sets can be compared directly.
bl = generate(pipe, sample_tags, shift=2.5, seed=42, num_steps=30)
save_images(bl, "/content/samples_burn_baseline")
show_images(bl)
165
+
166
# =============================================================================
# 5 — Sequential bucket training (shared geo_prior weights)
# =============================================================================

from sd15_trainer_geo.trainer import Trainer, TrainConfig, LatentDataset
from sd15_trainer_geo.analyze import GeometryProfiler

TOTAL_STEPS = 10_000
total_samples = sum(v["count"] for v in bucket_caches.values())
# Largest bucket first.
sorted_buckets = sorted(bucket_caches.items(), key=lambda x: -x[1]["count"])

profiler = GeometryProfiler(pipe, every=100)
all_log_history = []
cumulative = 0      # steps actually run so far, across all buckets

for (bw, bh), info in sorted_buckets:
    # Steps are split proportionally to sample count, with a floor of 500 per
    # bucket; because of the floor the sum over buckets can exceed TOTAL_STEPS.
    steps = max(500, int(TOTAL_STEPS * info["count"] / total_samples))

    print(f"\n{'='*60}")
    print(f"TRAINING {bw}×{bh}: {info['count']} samples, {steps} steps")
    print(f"{'='*60}")

    config = TrainConfig(
        num_steps=steps,
        batch_size=6,
        base_lr=5e-5,
        min_lr=1e-6,
        lr_scheduler="cosine",
        warmup_steps=min(200, steps // 5),
        shift=2.5,
        cfg_dropout=0.1,
        min_snr_gamma=5.0,
        geo_loss_weight=0.01,
        geo_loss_warmup=min(400, steps // 3),
        log_every=100,
        sample_every=max(500, steps // 4),
        save_every=max(500, steps // 4),
        sample_prompts=sample_tags[:4],
        seed=42,
        output_dir=f"/content/geo_prior_burn/{bw}x{bh}",
    )

    # A new Trainer is built per bucket, but every one of them trains the
    # same `pipe` object and reports to the same profiler.
    ds = LatentDataset(info["path"])
    trainer = Trainer(pipe, config)
    trainer.fit(ds, callbacks=[profiler])

    # Tag each log entry with its bucket and a step index that keeps counting
    # across buckets, then append it to the combined history.
    for entry in trainer.log_history:
        entry["bucket"] = f"{bw}x{bh}"
        entry["global_step"] = entry["step"] + cumulative
    all_log_history.extend(trainer.log_history)
    cumulative += steps

os.makedirs("/content/geo_prior_burn", exist_ok=True)
profiler.save("/content/geo_prior_burn/profiler.json")
with open("/content/geo_prior_burn/log_history.json", "w", encoding="utf-8") as f:
    json.dump(all_log_history, f, indent=2)
222
+
223
# =============================================================================
# 6 — Training analysis
# =============================================================================

from sd15_trainer_geo.analyze import analyze

# NOTE(review): `trainer` is whichever Trainer was created last in the bucket
# loop, so this presumably only sees that bucket's log history — confirm
# whether the combined history should be analyzed instead.
summary = analyze(
    trainer,
    profiler,
    save_dir="/content/analysis_burn",
)

# =============================================================================
# 7 — Post-training analysis
# =============================================================================

from sd15_trainer_geo.analyze_post import PostTrainingAnalyzer

post_analyzer = PostTrainingAnalyzer(pipe)
post = post_analyzer.run_all(save_dir="/content/post_analysis_burn")
236
+
237
# =============================================================================
# 8 — After-training samples
# =============================================================================

print("\n" + "=" * 60)
print("AFTER TRAINING — Same prompts")
print("=" * 60)
# Same prompts, seed and sampler settings as the baseline run, so any
# difference in the images comes from training.
trained = generate(pipe, sample_tags, shift=2.5, seed=42, num_steps=30)
save_images(trained, "/content/samples_burn_trained")
show_images(trained)

# 1person anchor tests — the key diagnostic
anchor_prompts = [
    "1person, good aesthetic, standing, full body",
    "1person, very displeasing, portrait, close up",
    "1person, good aesthetic, anime style, colorful background",
    "1person, very displeasing, dark, moody lighting",
]
print("\n" + "=" * 60)
print("ANCHOR TEST — 1person geometric routing")
print("=" * 60)
anchor = generate(pipe, anchor_prompts, shift=2.5, seed=42, num_steps=30)
save_images(anchor, "/content/samples_burn_anchor")
show_images(anchor)
261
+
262
# =============================================================================
# 9 — Push to hub
# =============================================================================

from sd15_trainer_geo.pipeline import push_geo_to_hub, save_geo_checkpoint
from huggingface_hub import HfApi

REPO = "AbstractPhil/sd15-geoflow-test-44"

save_geo_checkpoint(pipe, "/content/geo_prior_burn/geo_prior_final.pt")

# "WxH" -> number of cached samples, used in both metadata payloads below.
bucket_counts = {f"{k[0]}x{k[1]}": v["count"] for k, v in bucket_caches.items()}

push_geo_to_hub(
    pipe, repo_id=REPO,
    base_repo="sd-legacy/stable-diffusion-v1-5",
    # `cumulative` is the number of steps actually run; it can exceed
    # TOTAL_STEPS because every bucket trains for at least 500 steps.
    commit_message=f"burn test: {n_images} images × {repeats} repeats, AR bucketed, {cumulative} steps",
    extra={
        "test_type": "burn_test",
        "source_images": n_images,
        "repeats": repeats,
        "total_samples": actual_total,
        "total_steps": TOTAL_STEPS,       # configured step budget
        "actual_steps": cumulative,       # steps really executed across buckets
        "buckets": bucket_counts,
    },
)

api = HfApi()


def _upload(local_path, path_in_repo):
    """Upload one local file to REPO (model repo) at `path_in_repo`."""
    api.upload_file(path_or_fileobj=local_path,
                    path_in_repo=path_in_repo,
                    repo_id=REPO, repo_type="model")


# Upload analysis artifacts
for pattern, prefix in [
    ("/content/analysis_burn/*", "analysis"),
    ("/content/post_analysis_burn/*", "post_analysis"),
]:
    for artifact in glob.glob(pattern):
        if artifact.endswith((".png", ".json")):
            _upload(artifact, f"{prefix}/{os.path.basename(artifact)}")
            print(f"✓ {prefix}/{os.path.basename(artifact)}")

# Upload profiler + logs
for log_file in ["/content/geo_prior_burn/profiler.json",
                 "/content/geo_prior_burn/log_history.json"]:
    if os.path.exists(log_file):
        _upload(log_file, f"analysis/{os.path.basename(log_file)}")

# Upload bucket info
bucket_meta = {
    "source_images": n_images,
    "repeats": repeats,
    "buckets": bucket_counts,
    "tags": {s["path"].name: s["tags"] for s in samples},
}
meta_path = "/content/geo_prior_burn/bucket_info.json"
with open(meta_path, "w", encoding="utf-8") as fh:
    json.dump(bucket_meta, fh, indent=2)
_upload(meta_path, "bucket_info.json")

# Upload samples
for label, d in [("baseline", "/content/samples_burn_baseline"),
                 ("trained", "/content/samples_burn_trained"),
                 ("anchor", "/content/samples_burn_anchor")]:
    if not os.path.exists(d):
        continue
    for img_path in sorted(glob.glob(f"{d}/*.png")):
        _upload(img_path, f"samples/{label}/{os.path.basename(img_path)}")
    print(f"✓ samples/{label}/")

# Training checkpoint samples
for (bw, bh), _ in sorted_buckets:
    for img_path in glob.glob(f"/content/geo_prior_burn/{bw}x{bh}/samples/*.png"):
        _upload(img_path, f"samples/training_{bw}x{bh}/{os.path.basename(img_path)}")

# Source images for reference
for s in samples:
    _upload(str(s["path"]), f"source_images/{s['path'].name}")
print(f"✓ {len(samples)} source images")

print(f"\nhttps://huggingface.co/{REPO}")