AbstractPhil commited on
Commit
64736e9
·
verified ·
1 Parent(s): 3d7d199

Create classifier_trainer.py

Browse files
Files changed (1) hide show
  1. classifier_trainer.py +476 -0
classifier_trainer.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # CELL 3: Train Geometric Classifier + Upload to HuggingFace
3
+ # Requires: Cell 1 (generator/constants), Cell 2 (model classes)
4
+ # Outputs: `model` in notebook scope + geometric_classifier/ on HF
5
+ #
6
+ # Features:
7
+ # - Dataset cached to disk (skip regeneration on resume)
8
+ # - Checkpoint saved every epoch (model, optimizer, scheduler, epoch, best_acc)
9
+ # - Auto-resume from latest checkpoint
10
+ # =============================================================================
11
+
12
+ import json, time, os, shutil
13
+ from pathlib import Path
14
+
15
# Target HuggingFace Hub repository for the trained classifier artifacts.
HF_REPO = "AbstractPhil/grid-geometric-classifier-proto"
# Local directory for per-epoch training checkpoints (created on demand).
CKPT_DIR = Path("./checkpoints")
# On-disk cache of the generated dataset so reruns skip regeneration.
DATASET_PATH = Path("./cached_dataset.pt")
18
+
19
+ # --- Loss Functions ---
20
+
21
+ def _safe_bce(inp, tgt):
22
+ """BCE that forces fp32 and clamps to prevent log(0) from BF16 sigmoid saturation."""
23
+ with torch.amp.autocast('cuda', enabled=False):
24
+ return F.binary_cross_entropy(
25
+ inp.float().clamp(1e-7, 1 - 1e-7),
26
+ tgt.float())
27
+
28
def capacity_fill_loss(fr, dt):
    """Safe BCE between predicted fill ratios and target dim confidences."""
    return _safe_bce(fr, dt)
29
def overflow_reg(on, dt):
    """Vectorized overflow penalty for capacity dims at or above the peak.

    Fully tensorized — no Python loop over the batch, no .item() syncs.
    """
    batch, n_caps = on.shape[0], on.shape[1]
    peak = dt.sum(dim=-1).long().clamp(min=0)                    # (B,) peak dim index
    dims = torch.arange(n_caps, device=on.device).unsqueeze(0)   # (1, n_caps)
    at_or_above = (dims >= peak.unsqueeze(1)).float()            # (B, n_caps)
    return (on * at_or_above).sum() / (batch + 1e-8)
36
def cap_diversity(c):
    """Negative variance of capacities — a reward for spread-out values."""
    return -c.var()
37
def peak_loss(l, t):
    """Cross-entropy on peak-dimension logits."""
    return F.cross_entropy(l, t)
38
def cm_loss(p, t):
    """MSE between the prediction and the sign of the target (cm_det)."""
    return F.mse_loss(p, torch.sign(t))
39
def curved_bce(p, t):
    """Safe BCE for the is-curved head; drops the trailing logit dim."""
    return _safe_bce(p.squeeze(-1), t)
40
def ctype_loss(l, t):
    """Cross-entropy on curvature-type logits."""
    return F.cross_entropy(l, t)
41
+
42
+ # --- Dataset Cache ---
43
+
44
def get_or_generate_dataset(n_samples, seed, path=DATASET_PATH):
    """Load the cached dataset from disk, or generate it and cache it.

    The cache is keyed on (n_samples, seed); a mismatch triggers
    regeneration, and torch.save below overwrites the stale file.

    Args:
        n_samples: total sample count (80/20 train/val split).
        seed: RNG seed forwarded to the generator and stored in the cache key.
        path: cache file location (defaults to module-level DATASET_PATH).

    Returns:
        (train_ds, val_ds) ShapeDataset pair.
    """
    # Tensor attributes that fully describe a ShapeDataset instance.
    # Kept in ONE place so cache write and cache read cannot drift apart
    # (the original listed these three separate times).
    fields = ["grids", "labels", "dim_conf", "peak_dim", "volume",
              "cm_det", "is_curved", "curvature"]

    if path.exists():
        print(f"Loading cached dataset from {path}...")
        t0 = time.time()
        cached = torch.load(path, weights_only=True)
        if cached["n_samples"] == n_samples and cached["seed"] == seed:
            # Rebuild datasets without running __init__ (which expects
            # raw samples); the cached tensors are assigned directly.
            train_ds = ShapeDataset.__new__(ShapeDataset)
            val_ds = ShapeDataset.__new__(ShapeDataset)
            for k in fields:
                setattr(train_ds, k, cached["train"][k])
                setattr(val_ds, k, cached["val"][k])
            dt = time.time() - t0
            print(f"Loaded {len(train_ds)} train + {len(val_ds)} val in {dt:.1f}s (cached)")
            return train_ds, val_ds
        else:
            print(f"Cache mismatch (n={cached['n_samples']}, seed={cached['seed']}) — regenerating")

    all_samples = generate_parallel(n_samples, seed=seed, n_workers=8)
    n_train = int(len(all_samples) * 0.8)
    train_ds = ShapeDataset(all_samples[:n_train])
    val_ds = ShapeDataset(all_samples[n_train:])

    print(f"Caching dataset to {path}...")
    cache_data = {
        "n_samples": n_samples, "seed": seed,
        "train": {k: getattr(train_ds, k) for k in fields},
        "val": {k: getattr(val_ds, k) for k in fields},
    }
    torch.save(cache_data, path)
    size_mb = path.stat().st_size / 1e6
    print(f"Cached: {size_mb:.0f}MB")
    return train_ds, val_ds
77
+
78
+ # --- Checkpoint helpers ---
79
+
80
def save_checkpoint(model, optimizer, scheduler, epoch, best_val_acc, ckpt_dir=CKPT_DIR):
    """Persist full training state for this epoch plus a rolling latest.pt.

    Returns the path of the per-epoch checkpoint file.
    """
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    # Unwrap torch.compile's wrapper so state-dict keys match the plain module.
    target = model._orig_mod if hasattr(model, '_orig_mod') else model
    state = {
        "epoch": epoch,
        "best_val_acc": best_val_acc,
        "model_state_dict": target.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
    epoch_path = ckpt_dir / f"epoch_{epoch:03d}.pt"
    torch.save(state, epoch_path)
    # "latest.pt" always mirrors the newest checkpoint for auto-resume.
    torch.save(state, ckpt_dir / "latest.pt")
    return epoch_path
95
+
96
+
97
def load_checkpoint(model, optimizer, scheduler, ckpt_dir=CKPT_DIR):
    """Restore training state from ckpt_dir/latest.pt if present.

    Returns:
        (start_epoch, best_val_acc) — (0, 0.0) when no checkpoint exists.
    """
    latest = ckpt_dir / "latest.pt"
    if not latest.exists():
        return 0, 0.0
    print(f"Resuming from {latest}...")
    # weights_only=False: the file also holds pickled optimizer/scheduler
    # state — only load checkpoints you produced yourself.
    state = torch.load(latest, weights_only=False)
    target = model._orig_mod if hasattr(model, '_orig_mod') else model
    target.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    scheduler.load_state_dict(state["scheduler_state_dict"])
    next_epoch = state["epoch"] + 1
    best = state["best_val_acc"]
    print(f"Resumed: epoch {next_epoch}, best_val_acc={best:.4f}")
    return next_epoch, best
111
+
112
+
113
+ # --- Training ---
114
+
115
def train(n_samples: int = 500000, epochs: int = 80, batch_size: int = 4096,
          lr: float = 3e-3, seed: int = 42):
    """Train the geometric shape classifier end-to-end and publish artifacts.

    Pipeline: load/generate cached dataset -> build model -> resume from
    checkpoint -> torch.compile -> AMP training loop with multi-head losses
    -> per-epoch validation + checkpoint -> per-class report -> save config
    and weights to ./hf_staging and upload to the HF Hub.

    Relies on notebook-scope names from earlier cells (see file header:
    GeometricShapeClassifier, deform_grid, CLASS_NAMES, NUM_CLASSES,
    SHAPE_CATALOG, CURVATURE_TYPES, GS, plus torch/F/math/np imports).

    Returns:
        The trained model (possibly wrapped by torch.compile).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # --- GPU throughput knobs (TF32 matmul/conv, cuDNN autotune) ---
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Newer torch exposes a string-valued precision API; set it when present.
        if hasattr(torch.backends.cuda.matmul, 'fp32_precision'):
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
        if hasattr(torch.backends.cudnn, 'conv') and hasattr(torch.backends.cudnn.conv, 'fp32_precision'):
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
        torch.backends.cudnn.benchmark = True
        props = torch.cuda.get_device_properties(0)
        print(f"GPU: {props.name} | {props.total_memory / 1e9:.1f}GB | SM {props.major}.{props.minor}")
        print(f"TF32: enabled | cuDNN benchmark: enabled | batch: {batch_size}")

    # --- Data ---
    train_ds, val_ds = get_or_generate_dataset(n_samples, seed)
    print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | {NUM_CLASSES} classes | pre-tensored")

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True, persistent_workers=True)

    # --- Model ---
    model = GeometricShapeClassifier().to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {n_params:,} parameters")
    if device.type == "cuda":
        print(f"VRAM after model load: {torch.cuda.memory_allocated()/1e9:.2f}GB / "
              f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

    # --- AMP setup: bf16 when supported (no scaler needed), else fp16+scaler ---
    use_amp = device.type == "cuda"
    amp_dtype = torch.bfloat16 if (device.type == "cuda" and
                                   torch.cuda.is_bf16_supported()) else torch.float16
    use_scaler = use_amp and amp_dtype == torch.float16
    scaler = torch.amp.GradScaler('cuda', enabled=use_scaler)

    # --- Optimizer + linear-warmup cosine schedule (per-epoch stepping) ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    warmup_epochs = 5
    def lr_lambda(epoch):
        # Linear warmup for the first `warmup_epochs`, then cosine decay to 0.
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Resume from checkpoint (loads into model BEFORE compile, so the
    # state-dict keys match the uncompiled module)
    start_epoch, best_val_acc = load_checkpoint(model, optimizer, scheduler)

    # Compile AFTER loading checkpoint weights
    if device.type == "cuda" and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model, mode="default")
            print("torch.compile: enabled (default mode)")
        except Exception as e:
            print(f"torch.compile: skipped ({e})")

    print(f"AMP: {'bf16' if amp_dtype == torch.bfloat16 else 'fp16'}" +
          (f" (scaler: {'on' if use_scaler else 'off'})" if use_amp else " disabled"))

    # Loss-head weights: classifier, capacity heads, geometry heads,
    # arbiter (refined/trajectory/confidence) and rectified-flow terms.
    w = {"cls": 1.0, "fill": 0.3, "peak": 0.3, "ovf": 0.05,
         "div": 0.02, "vol": 0.1, "cm": 0.1, "curved": 0.2, "ctype": 0.2,
         "arb_cls": 0.8, "arb_traj": 0.2, "arb_conf": 0.1, "flow": 0.5}

    epoch_start = time.time()

    # NOTE(review): if a resumed checkpoint already has epoch >= epochs - 1
    # this loop body never runs, and `caps`, `dt`, `marker` etc. used below
    # (e.g. `caps` in train_cfg) would be undefined — confirm the resume
    # flow can never start at or past the final epoch.
    for epoch in range(start_epoch, epochs):
        t0 = time.time()
        model.train()
        correct, total = 0, 0
        correct_init, correct_ref = 0, 0

        for batch_idx, (grid, label, dc, pd, vol, cm, ic, ct) in enumerate(train_loader):
            # Batch fields: grid voxels, class label, dim confidence (dc),
            # peak dim (pd), volume, cm determinant, is-curved flag (ic),
            # curvature type (ct).
            grid, label = grid.to(device, non_blocking=True), label.to(device, non_blocking=True)
            dc, pd = dc.to(device, non_blocking=True), pd.to(device, non_blocking=True)
            vol, cm = vol.to(device, non_blocking=True), cm.to(device, non_blocking=True)
            ic, ct = ic.to(device, non_blocking=True), ct.to(device, non_blocking=True)

            # On-the-fly augmentation: random voxel dropout/addition/shift.
            grid = deform_grid(grid, p_dropout=0.05, p_add=0.05, p_shift=0.08)
            optimizer.zero_grad(set_to_none=True)

            try:
                with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
                    out = model(grid, labels=label)

                    # First-stage losses: classification + capacity/geometry heads.
                    loss_first = (w["cls"] * F.cross_entropy(out["initial_logits"], label) +
                                  w["fill"] * capacity_fill_loss(out["fill_ratios"], dc) +
                                  w["peak"] * peak_loss(out["peak_logits"], pd) +
                                  w["ovf"] * overflow_reg(out["overflows"], dc) +
                                  w["div"] * cap_diversity(out["capacities"]) +
                                  w["vol"] * F.mse_loss(out["volume_pred"], torch.log1p(vol)) +
                                  w["cm"] * cm_loss(out["cm_pred"], cm) +
                                  w["curved"] * curved_bce(out["is_curved_pred"], ic) +
                                  w["ctype"] * ctype_loss(out["curv_type_logits"], ct))

                    # Arbiter losses: refined classification + per-step
                    # trajectory CE (later steps weighted higher) + flow loss.
                    loss_arb = w["arb_cls"] * F.cross_entropy(out["refined_logits"], label)
                    traj_loss = 0
                    for step_i, step_logits in enumerate(out["trajectory_logits"]):
                        step_weight = (step_i + 1) / len(out["trajectory_logits"])
                        traj_loss += step_weight * F.cross_entropy(step_logits, label)
                    traj_loss /= len(out["trajectory_logits"])
                    loss_arb += w["arb_traj"] * traj_loss
                    loss_arb += w["flow"] * out["flow_loss"]

                    # Confidence head is trained to predict whether the
                    # refined prediction is correct (target detached).
                    with torch.no_grad():
                        is_correct = (out["refined_logits"].argmax(1) == label).float()
                    loss_arb += w["arb_conf"] * _safe_bce(
                        out["refined_confidence"].squeeze(-1), is_correct)

                    # Blend head target: 0.8 when the initial head is at
                    # least as good as the refined head, else 0.2.
                    with torch.no_grad():
                        init_correct = (out["initial_logits"].argmax(1) == label).float()
                        ref_correct = (out["refined_logits"].argmax(1) == label).float()
                        blend_target = torch.where(init_correct >= ref_correct,
                                                   torch.ones_like(init_correct) * 0.8,
                                                   torch.ones_like(init_correct) * 0.2)
                    loss_arb += 0.1 * _safe_bce(out["blend_weight"], blend_target)

                    loss_blend = w["cls"] * F.cross_entropy(out["class_logits"], label)
                    loss = loss_first + loss_arb + loss_blend

                # NaN guard: skip batch if loss is non-finite
                # (forces a host sync via .item() every batch).
                if not torch.isfinite(loss).item():
                    optimizer.zero_grad(set_to_none=True)
                    total += grid.size(0)
                    continue
                # Scaler calls are no-ops when use_scaler is False (bf16 path).
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

            except RuntimeError as e:
                # Dump diagnostics for device-side asserts, then re-raise;
                # all RuntimeErrors propagate (CUDA context is unusable after
                # a device-side assert anyway).
                if "CUDA" in str(e) or "device-side" in str(e):
                    print(f"\n!!! CUDA error at epoch {epoch+1}, batch {batch_idx} !!!")
                    print(f" Error: {e}")
                    print(f" label range: [{label.min().item()}, {label.max().item()}]")
                    print(f" pd range: [{pd.min().item()}, {pd.max().item()}]")
                    print(f" ct range: [{ct.min().item()}, {ct.max().item()}]")
                    print(f" Checkpoint saved at epoch {epoch-1}")
                    print(f" To diagnose: add os.environ['CUDA_LAUNCH_BLOCKING']='1' before training")
                raise

            # Running train accuracy for all three prediction heads.
            correct += (out["class_logits"].argmax(1) == label).sum().item()
            correct_init += (out["initial_logits"].argmax(1) == label).sum().item()
            correct_ref += (out["refined_logits"].argmax(1) == label).sum().item()
            total += grid.size(0)

        scheduler.step()
        train_acc = correct / total

        # One-time VRAM/throughput report on the first executed epoch.
        if epoch == start_epoch and device.type == "cuda":
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f"VRAM peak: {peak:.2f}GB | throughput: {total/(time.time()-t0):.0f} samples/s")

        # --- Validation ---
        model.eval()
        vc, vt, vcc, vct = 0, 0, 0, 0
        vc_init, vc_ref = 0, 0
        val_fills, val_alts, val_confs, val_blends = [], [], [], []

        with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            for grid, label, dc, pd, vol, cm, ic, ct in val_loader:
                grid, label = grid.to(device, non_blocking=True), label.to(device, non_blocking=True)
                ic = ic.to(device, non_blocking=True)
                out = model(grid)
                vc += (out["class_logits"].argmax(1) == label).sum().item()
                vc_init += (out["initial_logits"].argmax(1) == label).sum().item()
                vc_ref += (out["refined_logits"].argmax(1) == label).sum().item()
                vt += grid.size(0)
                vcc += ((out["is_curved_pred"].squeeze(-1) > 0.5).float() == ic).sum().item()
                vct += grid.size(0)
                val_fills.append(out["fill_ratios"].cpu())
                val_alts.append(out["alternation"].cpu())
                val_confs.append(out["confidence"].cpu())
                val_blends.append(out["blend_weight"].cpu())

        val_acc = vc / vt; val_init = vc_init / vt; val_ref = vc_ref / vt
        curved_acc = vcc / vct
        mf = torch.cat(val_fills).mean(dim=0)
        mc = torch.cat(val_confs).mean().item()
        mb = torch.cat(val_blends).mean().item()
        marker = " *" if val_acc > best_val_acc else ""
        if val_acc > best_val_acc: best_val_acc = val_acc

        # Read learned per-dimension capacities off the unwrapped model.
        # Assumes the model exposes dim0..dim3 submodules each with a
        # _raw_capacity parameter (softplus-transformed) — defined in Cell 2.
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        with torch.no_grad():
            caps = [F.softplus(getattr(raw_model, f"dim{d}")._raw_capacity).item() for d in range(4)]

        dt = time.time() - t0

        # Save checkpoint every epoch
        save_checkpoint(model, optimizer, scheduler, epoch, best_val_acc)

        # Progress logging: full line every 10 epochs / first epoch,
        # short line when a new best accuracy was reached.
        if (epoch + 1) % 10 == 0 or epoch == start_epoch or marker:
            if (epoch + 1) % 10 == 0 or epoch == start_epoch:
                print(f"Epoch {epoch+1:3d}/{epochs} [{dt:.1f}s {total/dt:.0f} s/s] | "
                      f"blend {val_acc:.3f} init {val_init:.3f} arb {val_ref:.3f} | "
                      f"conf {mc:.3f} blend_w {mb:.2f} | curved {curved_acc:.3f} | "
                      f"fill [{mf[0]:.2f} {mf[1]:.2f} {mf[2]:.2f} {mf[3]:.2f}] | "
                      f"cap [{caps[0]:.2f} {caps[1]:.2f} {caps[2]:.2f} {caps[3]:.2f}]{marker}")
            elif marker:
                print(f"Epoch {epoch+1:3d}/{epochs} [{dt:.1f}s] | "
                      f"blend {val_acc:.3f} init {val_init:.3f} arb {val_ref:.3f} | conf {mc:.3f}{marker}")

    total_time = time.time() - epoch_start
    print(f"\nTraining complete in {total_time:.0f}s ({total_time/60:.1f}min)")
    print(f"Best val accuracy: {best_val_acc:.4f}")
    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model

    # --- Per-class breakdown ---
    # cc_b/cc_i/cc_r: per-class correct counts for blend/initial/refined
    # heads; ct_c: per-class totals; cf/cconf/cblend: per-class fill ratios,
    # confidences, and blend weights.
    cc_b, cc_i, cc_r = {n: 0 for n in CLASS_NAMES}, {n: 0 for n in CLASS_NAMES}, {n: 0 for n in CLASS_NAMES}
    ct_c = {n: 0 for n in CLASS_NAMES}
    cf = {n: [] for n in CLASS_NAMES}
    cconf = {n: [] for n in CLASS_NAMES}
    cblend = {n: [] for n in CLASS_NAMES}

    with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
        for grid, label, *_ in val_loader:
            grid, label = grid.to(device, non_blocking=True), label.to(device, non_blocking=True)
            out = model(grid)
            pb = out["class_logits"].argmax(1)
            pi = out["initial_logits"].argmax(1)
            pr = out["refined_logits"].argmax(1)
            for k in range(len(label)):
                name = CLASS_NAMES[label[k].item()]
                cc_b[name] += (pb[k] == label[k]).item()
                cc_i[name] += (pi[k] == label[k]).item()
                cc_r[name] += (pr[k] == label[k]).item()
                ct_c[name] += 1
                cf[name].append(out["fill_ratios"][k].cpu().numpy())
                cconf[name].append(out["confidence"][k].item())
                cblend[name].append(out["blend_weight"][k].item())

    print(f"\n{'Class':22s} | {'Blend':>5s} {'Init':>5s} {'Arb':>5s} | "
          f"{'Conf':>5s} {'Bld':>4s} | {'Corr':>4s}/{'Tot':>4s} | "
          f"{'Fill Ratios':22s} | {'Type':8s} Curvature")
    print("-" * 110)
    for name in CLASS_NAMES:
        if ct_c[name] == 0: continue
        ab = cc_b[name]/ct_c[name]; ai = cc_i[name]/ct_c[name]; ar = cc_r[name]/ct_c[name]
        mfv = np.mean(cf[name], axis=0)
        mconf = np.mean(cconf[name]); mblend = np.mean(cblend[name])
        info = SHAPE_CATALOG[name]
        # Flag classes where the arbiter improves on the initial head by >1pt.
        arb_flag = f" +{ar-ai:+.3f}" if ar-ai > 0.01 else ""
        print(f" {name:20s} | {ab:.3f} {ai:.3f} {ar:.3f} | "
              f"{mconf:.3f} {mblend:.2f} | {cc_b[name]:4d}/{ct_c[name]:4d} | "
              f"[{mfv[0]:.2f} {mfv[1]:.2f} {mfv[2]:.2f} {mfv[3]:.2f}] | "
              f"{'CURVED' if info['curved'] else 'rigid':8s} {info['curvature']}{arb_flag}")

    # Top-10 classes by arbiter gain (refined minus initial accuracy).
    print(f"\n--- Arbiter Impact Summary ---")
    imps = [(n, cc_i[n]/ct_c[n], cc_r[n]/ct_c[n], cc_b[n]/ct_c[n], cc_r[n]/ct_c[n]-cc_i[n]/ct_c[n])
            for n in CLASS_NAMES if ct_c[n] > 0]
    imps.sort(key=lambda x: x[4], reverse=True)
    print(f" {'Class':20s} | {'Init':>5s} {'Arb':>5s} {'Blend':>5s} | {'Δ':>6s}")
    for name, ai, ar, ab, delta in imps[:10]:
        print(f" {name:20s} | {ai:.3f} {ar:.3f} {ab:.3f} | {delta:+.3f}")

    # =========================================================================
    # Upload geometric_classifier/ to HuggingFace
    # =========================================================================
    print("\n" + "=" * 70)
    print("Saving geometric_classifier/ to HuggingFace")
    print("=" * 70)

    raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
    staging = Path("./hf_staging/geometric_classifier")
    staging.mkdir(parents=True, exist_ok=True)

    # NOTE(review): these architecture numbers (embed_dim, n_tracers, ...)
    # are hard-coded here; confirm they match Cell 2's model definition.
    arch_config = {
        "model_type": "GeometricShapeClassifier",
        "version": "v8",
        "grid_size": GS,
        "num_classes": NUM_CLASSES,
        "class_names": CLASS_NAMES,
        "curvature_types": CURVATURE_TYPES,
        "embed_dim": 128,
        "n_tracers": 5,
        "capacity_dims": [64, 64, 64, 64],
        "curvature_embed_dim": 128,
        "arbiter_latent_dim": 128,
        "arbiter_flow_steps": 4,
        "total_params": sum(p.numel() for p in raw_model.parameters()),
        "shape_catalog": {k: v for k, v in SHAPE_CATALOG.items()},
    }
    with open(staging / "config.json", "w") as f:
        json.dump(arch_config, f, indent=2)

    train_cfg = {
        "n_samples": n_samples, "epochs": epochs, "batch_size": batch_size,
        "lr": lr, "seed": seed, "optimizer": "AdamW", "weight_decay": 1e-4,
        "scheduler": "cosine_with_warmup", "warmup_epochs": warmup_epochs,
        "amp_dtype": str(amp_dtype), "loss_weights": w,
        "best_val_accuracy": best_val_acc, "learned_capacities": caps,
        "total_training_time_seconds": total_time,
    }
    with open(staging / "training_config.json", "w") as f:
        json.dump(train_cfg, f, indent=2)

    # Prefer safetensors; fall back to a plain torch .pt file.
    try:
        from safetensors.torch import save_file as st_save
        st_save(raw_model.state_dict(), str(staging / "model.safetensors"))
        print(f" Saved: model.safetensors")
    except ImportError:
        torch.save(raw_model.state_dict(), staging / "model.pt")
        print(f" Saved: model.pt (install safetensors for .safetensors)")

    # Best-effort upload: any failure leaves the artifacts on local disk.
    try:
        from huggingface_hub import HfApi, create_repo
        token = None
        # Token lookup: Colab secret first, then the HF_TOKEN env var.
        try:
            from google.colab import userdata
            token = userdata.get('HF_TOKEN')
        except Exception:
            token = os.environ.get('HF_TOKEN')

        if token:
            api = HfApi(token=token)
            create_repo(HF_REPO, token=token, exist_ok=True)
            readme = Path("./hf_staging/README.md")
            readme.write_text(f"""---
license: mit
tags:
- geometric-deep-learning
- voxel-classifier
- cross-contrast
- pentachoron
---

# Grid Geometric Classifier Proto

Geometric primitive classifier using 5x5x5 binary voxel grids with capacity cascade,
curvature analysis, differentiation gates, and rectified flow arbiter.

## Structure

```
geometric_classifier/ # Voxel classifier (v8, ~1.85M params)
crosscontrast/ # Text-Voxel alignment heads
qwen_embeddings/ # Cached Qwen 2.5-1.5B embeddings
```

## 38 Shape Classes

Rigid 0D-3D: point, lines, triangles, quads, tetrahedra, cubes, prisms, octahedra, pentachoron
Curved 1D-3D: arcs, helices, circles, ellipses, discs, spheres, hemispheres, cylinders, cones, capsules, tori, shells, tubes, bowls, saddles
""")
            api.upload_file(path_or_fileobj=str(readme), path_in_repo="README.md",
                            repo_id=HF_REPO, token=token, commit_message="README")
            api.upload_folder(
                folder_path=str(staging), repo_id=HF_REPO,
                path_in_repo="geometric_classifier", token=token,
                commit_message=f"geometric_classifier v8 | acc={best_val_acc:.4f} | {sum(p.numel() for p in raw_model.parameters()):,} params")
            print(f"Uploaded: https://huggingface.co/{HF_REPO}/tree/main/geometric_classifier")
        else:
            print("No HF_TOKEN — saved locally at ./hf_staging/geometric_classifier/")
    except Exception as e:
        print(f"HF upload failed: {e}\n Weights at ./hf_staging/geometric_classifier/")

    return model
474
+
475
+
476
# Run training at cell-execution time; `model` stays in notebook scope
# for downstream cells (see file header).
model = train()