AbstractPhil committed on
Commit
14993d7
·
verified ·
1 Parent(s): 58eb211

Create mlp_ablation.py

Browse files
Files changed (1) hide show
  1. mlp_ablation.py +433 -0
mlp_ablation.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run cell 1 and cell 2, shape factory and model then run this to continue.
2
+ # ablation showed random chance without the full geometric architecture.
3
+
4
+ # =============================================================================
5
+ # CELL 5: Architecture Ablation — MLP Baseline with Same Loss
6
+ # Requires: Cell 1 + Cell 2 already executed (constants, generator, deform_grid)
7
+ # Question: Is the loss creating the behavior, or the architecture?
8
+ #
9
+ # Same composite loss, same data, same hyperparams.
10
+ # Plain MLP replaces: tracer attention, capacity cascade,
11
+ # differentiation gate, curvature head, rectified flow arbiter.
12
+ # =============================================================================
13
+
14
import math, time, numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

# Pick the compute device; when a GPU is present, enable the fast-math
# settings that are safe for training (TF32 matmuls, cuDNN autotuning).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# Mixed precision only on GPU; prefer bf16 where supported (no loss
# scaling required), otherwise fall back to fp16.
use_amp = device.type == "cuda"
if device.type == "cuda" and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16
30
+
31
# =============================================================================
# MLP Baseline — same output contract as GeometricShapeClassifier
# =============================================================================

class MLPBaseline(nn.Module):
    """Plain MLP producing the same output dict as GeometricShapeClassifier.

    No geometric inductive bias: the voxel grid is flattened and pushed
    through a shared trunk, and every auxiliary quantity the geometric model
    predicts (fill ratios, capacities, peak dimension, overflow, volume,
    Cayley-Menger sign, curvature, arbiter logits, confidence/blend) comes
    from an independent head over the same trunk features.  The loss surface
    is therefore identical; only the architecture differs.
    """

    def __init__(self, grid_size=GS, n_classes=NUM_CLASSES,
                 n_curvatures=NUM_CURVATURES, trunk_dim=256):
        super().__init__()
        inp = grid_size ** 3  # flattened occupancy grid (e.g. 5^3 = 125)

        # Shared feature trunk.
        self.trunk = nn.Sequential(
            nn.Linear(inp, 512), nn.GELU(),
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, trunk_dim), nn.GELU(),
            nn.Linear(trunk_dim, trunk_dim), nn.GELU(),
        )

        # Primary classifier
        self.classifier = nn.Sequential(
            nn.Linear(trunk_dim, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, n_classes))

        # Capacity analog: fill ratios (4 dims, sigmoid)
        self.fill_head = nn.Sequential(
            nn.Linear(trunk_dim, 64), nn.GELU(),
            nn.Linear(64, 4), nn.Sigmoid())

        # Learned capacities for diversity loss
        self.cap_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Softplus())

        # Peak dimension (4-class)
        self.peak_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(), nn.Linear(32, 4))

        # Overflow (4 dims, sigmoid)
        self.overflow_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(),
            nn.Linear(32, 4), nn.Sigmoid())

        # Volume regression (trained against log1p(volume) downstream)
        self.volume_head = nn.Sequential(
            nn.Linear(trunk_dim, 64), nn.GELU(), nn.Linear(64, 1))

        # Cayley-Menger determinant sign (tanh -> [-1, 1])
        self.cm_head = nn.Sequential(
            nn.Linear(trunk_dim, 64), nn.GELU(),
            nn.Linear(64, 1), nn.Tanh())

        # Curvature binary (curved vs rigid)
        self.curved_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(),
            nn.Linear(32, 1), nn.Sigmoid())

        # Curvature type (n_curvatures-class)
        self.curv_type_head = nn.Sequential(
            nn.Linear(trunk_dim, 64), nn.GELU(),
            nn.Linear(64, n_curvatures))

        # Second classifier (arbiter analog)
        self.refiner = nn.Sequential(
            nn.Linear(trunk_dim, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, n_classes))

        # Confidence and blend
        self.confidence_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(),
            nn.Linear(32, 1), nn.Sigmoid())
        self.blend_head = nn.Sequential(
            nn.Linear(trunk_dim, 32), nn.GELU(),
            nn.Linear(32, 1), nn.Sigmoid())

    def forward(self, grid, labels=None):
        """Flatten the grid, run the trunk, and emit the full output dict.

        `labels` is accepted only for interface parity with the geometric
        model and is unused here.
        """
        B = grid.shape[0]
        x = grid.reshape(B, -1).float()
        feat = self.trunk(x)

        initial_logits = self.classifier(feat)
        refined_logits = self.refiner(feat)

        # Learned convex blend of the two classifiers.
        blend = self.blend_head(feat).squeeze(-1)
        class_logits = (blend.unsqueeze(-1) * initial_logits +
                        (1 - blend.unsqueeze(-1)) * refined_logits)

        # Fix: run the confidence head once and reuse the result for both
        # the squeezed "confidence" and the (B, 1) "refined_confidence"
        # outputs — the original invoked the head twice per forward pass.
        conf_raw = self.confidence_head(feat)
        conf = conf_raw.squeeze(-1)

        return {
            "class_logits": class_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "fill_ratios": self.fill_head(feat),
            "peak_logits": self.peak_head(feat),
            "overflows": self.overflow_head(feat),
            "capacities": self.cap_head(feat),
            "volume_pred": self.volume_head(feat).squeeze(-1),
            "cm_pred": self.cm_head(feat).squeeze(-1),
            "is_curved_pred": self.curved_head(feat),
            "curv_type_logits": self.curv_type_head(feat),
            # No iterative flow arbiter here: a single-step "trajectory"
            # and a zero flow loss keep the training code path identical.
            "trajectory_logits": [refined_logits],
            "flow_loss": torch.tensor(0.0, device=grid.device),
            "refined_confidence": conf_raw,
            "blend_weight": blend,
            "confidence": conf,
            "alternation": torch.zeros(B, device=grid.device),
            "features": feat,
        }
141
+
142
+
143
+ # =============================================================================
144
+ # Loss Functions (identical to Cell 3)
145
+ # =============================================================================
146
+
147
+ def _safe_bce(inp, tgt):
148
+ with torch.amp.autocast('cuda', enabled=False):
149
+ return F.binary_cross_entropy(inp.float(), tgt.float())
150
+
151
def capacity_fill_loss(fr, dt):
    """fp32 BCE between predicted fill ratios and dim-confidence targets."""
    return _safe_bce(fr, dt)
152
+
153
def overflow_reg(on, dt):
    """Penalize predicted overflow in dimensions at or above the peak dim.

    `dt` is a (B, 4) 0/1 dimensional-confidence matrix; its row sum minus
    one gives each sample's peak dimension index.  Returns the mean (over
    the batch) of the overflow mass at indices >= peak.

    Fixes vs. the original:
      * vectorized mask instead of a per-sample Python loop with `.item()`
        (which forced one GPU sync per sample);
      * peak index clamped to 0 — the original's `on[b, pk:]` silently
        wrapped to only the last column when a `dt` row was all zeros.
    """
    pk = (dt.sum(dim=-1).long() - 1).clamp_min(0)        # (B,) peak index
    dims = torch.arange(on.shape[-1], device=on.device)  # (4,)
    mask = dims.unsqueeze(0) >= pk.unsqueeze(1)          # (B, 4)
    return (on * mask).sum() / (on.shape[0] + 1e-8)
157
+
158
def cap_diversity(c):
    """Negative variance of the learned capacities — rewards spread."""
    return -c.var()


def peak_loss(l, t):
    """Cross-entropy over the 4-way peak-dimension prediction."""
    return F.cross_entropy(l, t)


def cm_loss(p, t):
    """MSE of the tanh head against the sign of the Cayley-Menger det."""
    return F.mse_loss(p, torch.sign(t))


def curved_bce(p, t):
    """fp32 BCE for the curved/rigid binary head."""
    return _safe_bce(p.squeeze(-1), t)


def ctype_loss(l, t):
    """Cross-entropy over the curvature-type classes."""
    return F.cross_entropy(l, t)
163
+
164
+
165
# =============================================================================
# Data — load cached or generate + cache
# =============================================================================

# Cache file for the generated dataset; validated against (n_samples, seed)
# below so a config change forces regeneration.
DATASET_PATH = Path("./cached_dataset.pt")
N_SAMPLES = 500000
SEED = 42

if DATASET_PATH.exists():
    print(f"Loading cached dataset from {DATASET_PATH}...")
    t0 = time.time()
    # weights_only=True: safe load — only tensors/primitives, no pickled code.
    _cached = torch.load(DATASET_PATH, weights_only=True)
    # Only accept the cache if it was built with the same configuration.
    if _cached["n_samples"] == N_SAMPLES and _cached["seed"] == SEED:
        # __new__ bypasses ShapeDataset.__init__ (which would regenerate
        # samples); the tensor attributes are restored straight from cache.
        # NOTE(review): ShapeDataset comes from Cell 1/2, not this file.
        train_ds = ShapeDataset.__new__(ShapeDataset)
        val_ds = ShapeDataset.__new__(ShapeDataset)
        for k in ["grids", "labels", "dim_conf", "peak_dim", "volume",
                  "cm_det", "is_curved", "curvature"]:
            setattr(train_ds, k, _cached["train"][k])
            setattr(val_ds, k, _cached["val"][k])
        print(f"Loaded {len(train_ds)} train + {len(val_ds)} val in {time.time()-t0:.1f}s")
    else:
        print(f"Cache mismatch — regenerating")
        # Remove the stale file so the generation branch below runs.
        DATASET_PATH.unlink()

if not DATASET_PATH.exists():
    print("Generating dataset...")
    # generate_parallel is defined in an earlier cell.
    all_samples = generate_parallel(N_SAMPLES, seed=SEED, n_workers=8)
    n_train = int(len(all_samples) * 0.8)  # 80/20 train/val split
    train_ds = ShapeDataset(all_samples[:n_train])
    val_ds = ShapeDataset(all_samples[n_train:])

    print(f"Caching to {DATASET_PATH}...")
    cache_data = {
        "n_samples": N_SAMPLES, "seed": SEED,
        "train": {k: getattr(train_ds, k) for k in ["grids", "labels", "dim_conf",
                  "peak_dim", "volume", "cm_det", "is_curved", "curvature"]},
        "val": {k: getattr(val_ds, k) for k in ["grids", "labels", "dim_conf",
                "peak_dim", "volume", "cm_det", "is_curved", "curvature"]},
    }
    torch.save(cache_data, DATASET_PATH)
    size_mb = DATASET_PATH.stat().st_size / 1e6
    print(f"Cached: {size_mb:.0f}MB | {len(train_ds)} train + {len(val_ds)} val")

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=4096, shuffle=True,
    num_workers=4, pin_memory=True, persistent_workers=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=4096, shuffle=False,
    num_workers=4, pin_memory=True, persistent_workers=True)
214
+
215
+
216
# =============================================================================
# Train
# =============================================================================

model = MLPBaseline().to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"MLPBaseline: {n_params:,} params")
print(f"(GeometricShapeClassifier was 1,852,870 params)")

# Compile the model where supported; failure is non-fatal.
if device.type == "cuda" and hasattr(torch, 'compile'):
    try:
        model = torch.compile(model, mode="default")
        print("torch.compile: enabled")
    except Exception as e:
        print(f"torch.compile: skipped ({e})")

epochs = 80
lr = 3e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
warmup_epochs = 5


def lr_lambda(ep):
    """Linear warmup over `warmup_epochs`, then cosine decay to zero."""
    if ep < warmup_epochs:
        return (ep + 1) / warmup_epochs
    progress = (ep - warmup_epochs) / (epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Loss-term weights — identical to the geometric model's run.
w = {"cls": 1.0, "fill": 0.3, "peak": 0.3, "ovf": 0.05,
     "div": 0.02, "vol": 0.1, "cm": 0.1, "curved": 0.2, "ctype": 0.2,
     "arb_cls": 0.8, "arb_traj": 0.2, "arb_conf": 0.1, "flow": 0.5}

# GradScaler is only needed for fp16; bf16 trains unscaled, so the scaler
# becomes a pass-through (enabled=False) in that case.
use_scaler = use_amp and amp_dtype == torch.float16
scaler = torch.amp.GradScaler('cuda', enabled=use_scaler)

print(f"\nAblation: MLPBaseline vs GeometricShapeClassifier")
print(f"Same loss ({len(w)} terms), same data, same schedule")
print(f"{'='*70}")
251
+
252
best_val_acc = 0
t_start = time.time()

for epoch in range(epochs):
    t0 = time.time()
    model.train()
    correct, total = 0, 0
    correct_init, correct_ref = 0, 0

    # Batch fields: dc=dim_conf, pd=peak_dim, vol=volume, cm=cm_det,
    # ic=is_curved, ct=curvature type.
    for grid, label, dc, pd, vol, cm, ic, ct in train_loader:
        grid = grid.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        dc = dc.to(device, non_blocking=True)
        pd = pd.to(device, non_blocking=True)
        vol = vol.to(device, non_blocking=True)
        cm = cm.to(device, non_blocking=True)
        ic = ic.to(device, non_blocking=True)
        ct = ct.to(device, non_blocking=True)

        # On-the-fly augmentation: voxel dropout / addition / shift.
        grid = deform_grid(grid, p_dropout=0.05, p_add=0.05, p_shift=0.08)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            out = model(grid, labels=label)

            # First-stage composite loss: classification plus all of the
            # auxiliary geometric heads (same weights as the geometric run).
            loss_first = (w["cls"] * F.cross_entropy(out["initial_logits"], label) +
                          w["fill"] * capacity_fill_loss(out["fill_ratios"], dc) +
                          w["peak"] * peak_loss(out["peak_logits"], pd) +
                          w["ovf"] * overflow_reg(out["overflows"], dc) +
                          w["div"] * cap_diversity(out["capacities"]) +
                          w["vol"] * F.mse_loss(out["volume_pred"], torch.log1p(vol)) +
                          w["cm"] * cm_loss(out["cm_pred"], cm) +
                          w["curved"] * curved_bce(out["is_curved_pred"], ic) +
                          w["ctype"] * ctype_loss(out["curv_type_logits"], ct))

            # Arbiter-analog losses: refined classifier, trajectory steps
            # (single step for the MLP), and the flow loss (zero here).
            loss_arb = w["arb_cls"] * F.cross_entropy(out["refined_logits"], label)
            traj_loss = 0
            for step_i, step_logits in enumerate(out["trajectory_logits"]):
                # Later steps are weighted more heavily.
                step_weight = (step_i + 1) / len(out["trajectory_logits"])
                traj_loss += step_weight * F.cross_entropy(step_logits, label)
            traj_loss /= len(out["trajectory_logits"])
            loss_arb += w["arb_traj"] * traj_loss
            loss_arb += w["flow"] * out["flow_loss"]

            # Confidence head learns to predict whether the refined
            # classifier got each sample right (target is detached).
            with torch.no_grad():
                is_correct = (out["refined_logits"].argmax(1) == label).float()
            loss_arb += w["arb_conf"] * _safe_bce(
                out["refined_confidence"].squeeze(-1), is_correct)

            # Blend head is pushed toward 0.8 when the initial classifier
            # is at least as good as the refined one, else toward 0.2.
            with torch.no_grad():
                init_correct = (out["initial_logits"].argmax(1) == label).float()
                ref_correct = (out["refined_logits"].argmax(1) == label).float()
                blend_target = torch.where(init_correct >= ref_correct,
                                           torch.ones_like(init_correct) * 0.8,
                                           torch.ones_like(init_correct) * 0.2)
            loss_arb += 0.1 * _safe_bce(out["blend_weight"], blend_target)

            # Final blended classification loss.
            loss_blend = w["cls"] * F.cross_entropy(out["class_logits"], label)
            loss = loss_first + loss_arb + loss_blend

        # Standard AMP step; the scaler is a no-op pass-through when
        # disabled (bf16 or CPU).
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        correct += (out["class_logits"].argmax(1) == label).sum().item()
        correct_init += (out["initial_logits"].argmax(1) == label).sum().item()
        correct_ref += (out["refined_logits"].argmax(1) == label).sum().item()
        total += grid.size(0)

    scheduler.step()

    # One-time memory/throughput report after the first epoch.
    if epoch == 0 and device.type == "cuda":
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"VRAM peak: {peak:.2f}GB | throughput: {total/(time.time()-t0):.0f} samples/s")

    # Validate
    model.eval()
    # vc/vt: blended-head correct/total; vcc/vct: curved-head correct/total.
    vc, vt, vcc, vct = 0, 0, 0, 0
    vc_init, vc_ref = 0, 0

    with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
        for grid, label, dc, pd, vol, cm, ic, ct in val_loader:
            grid = grid.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            ic = ic.to(device, non_blocking=True)
            out = model(grid)
            vc += (out["class_logits"].argmax(1) == label).sum().item()
            vc_init += (out["initial_logits"].argmax(1) == label).sum().item()
            vc_ref += (out["refined_logits"].argmax(1) == label).sum().item()
            vt += grid.size(0)
            vcc += ((out["is_curved_pred"].squeeze(-1) > 0.5).float() == ic).sum().item()
            vct += grid.size(0)

    val_acc = vc / vt
    val_init = vc_init / vt
    val_ref = vc_ref / vt
    curved_acc = vcc / vct
    marker = " *" if val_acc > best_val_acc else ""
    if val_acc > best_val_acc:
        best_val_acc = val_acc

    dt = time.time() - t0
    # Log on the first epoch, every 10th epoch, and on every new best.
    if (epoch + 1) % 10 == 0 or epoch == 0 or marker:
        print(f"Ep {epoch+1:3d}/{epochs} [{dt:.1f}s] | "
              f"blend {val_acc:.3f} init {val_init:.3f} arb {val_ref:.3f} | "
              f"curved {curved_acc:.3f}{marker}")

total_time = time.time() - t_start
print(f"\nDone in {total_time:.0f}s ({total_time/60:.1f}min)")
363
+
364
+
365
# =============================================================================
# Per-Class Breakdown
# =============================================================================

print(f"\n{'='*70}")
print(f"Per-Class Results — MLPBaseline")
print(f"{'='*70}")

model.eval()
# Per-class correct counts for the blended (b), initial (i) and arbiter (r)
# heads, plus per-class sample totals.
cc_b = {n: 0 for n in CLASS_NAMES}
cc_i = {n: 0 for n in CLASS_NAMES}
cc_r = {n: 0 for n in CLASS_NAMES}
ct_c = {n: 0 for n in CLASS_NAMES}

with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
    for grid, label, *_ in val_loader:
        grid = grid.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        out = model(grid)
        pb = out["class_logits"].argmax(1)
        pi = out["initial_logits"].argmax(1)
        pr = out["refined_logits"].argmax(1)
        # Vectorized per-class tally.  Fix: the original iterated over
        # every sample in Python with one `.item()` GPU sync per element;
        # masking per class produces identical counts far faster.
        for ci, name in enumerate(CLASS_NAMES):
            cls_mask = label == ci
            n_cls = int(cls_mask.sum().item())
            if n_cls == 0:
                continue
            ct_c[name] += n_cls
            cc_b[name] += int((pb[cls_mask] == ci).sum().item())
            cc_i[name] += int((pi[cls_mask] == ci).sum().item())
            cc_r[name] += int((pr[cls_mask] == ci).sum().item())

print(f"\n{'Class':22s} | {'Blend':>5s} {'Init':>5s} {'Arb':>5s} | "
      f"{'Corr':>4s}/{'Tot':>4s} | {'Type':8s} Curvature")
print("-" * 85)
for name in CLASS_NAMES:
    if ct_c[name] == 0: continue
    ab = cc_b[name]/ct_c[name]
    ai = cc_i[name]/ct_c[name]
    ar = cc_r[name]/ct_c[name]
    info = SHAPE_CATALOG[name]
    print(f" {name:20s} | {ab:.3f} {ai:.3f} {ar:.3f} | "
          f"{cc_b[name]:4d}/{ct_c[name]:4d} | "
          f"{'CURVED' if info['curved'] else 'rigid':8s} {info['curvature']}")
406
+
407
+
408
# =============================================================================
# Summary Comparison
# =============================================================================

print(f"\n{'='*70}")
print(f"ABLATION SUMMARY")
print(f"{'='*70}")
print(f" MLPBaseline: {n_params:>10,} params | best val acc: {best_val_acc:.4f}")
print(f" GeometricShapeClassifier: 1,852,870 params | best val acc: 0.9022")
print(f" Delta: {n_params - 1852870:>+10,} params | "
      f"delta acc: {best_val_acc - 0.9022:+.4f}")
print()

# Interpretation thresholds: >=0.89 means the MLP matched the geometric
# model; >=0.80 a partial match; below that the architecture dominates.
if best_val_acc >= 0.89:
    verdict = [
        " -> Loss is doing most of the work.",
        " The composite multi-task signal is sufficient to discover",
        " geometric structure without architectural inductive bias.",
    ]
elif best_val_acc >= 0.80:
    verdict = [
        " -> Architecture contributes meaningfully.",
        " The loss provides signal but the geometric inductive bias",
        " (capacity cascade, tracers, flow arbiter) adds real value.",
    ]
else:
    verdict = [
        " -> Architecture is critical.",
        " The MLP cannot recover the same behavior from loss alone.",
        " Geometric inductive bias is doing the heavy lifting.",
    ]
for line in verdict:
    print(line)

print(f"\n Curved detection: {curved_acc:.3f}")
print(f" Training time: {total_time:.0f}s ({total_time/60:.1f}min)")