AbstractPhil commited on
Commit
2aa715e
Β·
verified Β·
1 Parent(s): be11bf0

Create experiment_2/experiment_1_adam_retuning_backprop_adjustment_sweep.py

Browse files
experiment_2/experiment_1_adam_retuning_backprop_adjustment_sweep.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # RIGID PATCHWORK CLASSIFIER + GATE SWEEP
3
+ #
4
+ # No conv4d. No composition paths. No splatting.
5
+ #
6
+ # Patchwork: partition 30 anchors into K compartments.
7
+ # Each compartment gets its own MLP that processes the triangulation
8
+ # distances for its assigned anchors. Compartment outputs concatenate.
9
+ # Final MLP β†’ classifier.
10
+ #
11
+ # Gate sweep: vary the CV gate tolerance and normal passthrough
12
+ # to find the behavior regime.
13
+ # ============================================================================
14
+
15
+ import math
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+
24
+ # ══════════════════════════════════════════════════════════════════
25
+ # GEOMETRIC PRIMITIVES (production versions, differentiable)
26
+ # ══════════════════════════════════════════════════════════════════
27
+
28
+
29
+ def tangential_projection(grad, embedding):
30
+ emb_n = F.normalize(embedding.detach().float(), dim=-1)
31
+ grad_f = grad.float()
32
+ radial = (grad_f * emb_n).sum(dim=-1, keepdim=True) * emb_n
33
+ return (grad_f - radial).to(grad.dtype), radial.to(grad.dtype)
34
+
35
+
36
+ # ── Production Cayley-Menger (generic, differentiable) ──
37
+
38
+ def cayley_menger_vol2(pts):
39
+ """Differentiable pentachoron volumeΒ². Generic for any V vertices."""
40
+ pts = pts.float()
41
+ diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
42
+ d2 = (diff * diff).sum(-1)
43
+ B, V, _ = d2.shape
44
+ cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
45
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
46
+ s = (-1.0)**V; f = math.factorial(V-1)
47
+ return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
48
+
49
+
50
+ def cv_loss(emb, target=0.2, n_samples=16):
51
+ """
52
+ Differentiable CV loss. Proper loss term, not gradient surgery.
53
+ Flows gradient through torch.stack β†’ torch.sqrt β†’ torch.std/mean.
54
+ """
55
+ B = emb.shape[0]
56
+ if B < 5: return torch.tensor(0.0, device=emb.device)
57
+ vols = []
58
+ for _ in range(n_samples):
59
+ idx = torch.randperm(B, device=emb.device)[:5]
60
+ v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
61
+ vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
62
+ stacked = torch.stack(vols)
63
+ cv = stacked.std() / (stacked.mean() + 1e-8)
64
+ return (cv - target).abs()
65
+
66
+
67
+ @torch.no_grad()
68
+ def cv_metric(emb, n_samples=200):
69
+ """Non-differentiable CV measurement for logging."""
70
+ B = emb.shape[0]
71
+ if B < 5: return 0.0
72
+ emb_f = emb.detach().float()
73
+ vols = []
74
+ for _ in range(n_samples):
75
+ idx = torch.randperm(B, device=emb.device)[:5]
76
+ v2 = cayley_menger_vol2(emb_f[idx].unsqueeze(0))
77
+ v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
78
+ if v > 0: vols.append(v)
79
+ if len(vols) < 10: return 0.0
80
+ vols_t = torch.tensor(vols)
81
+ return float(vols_t.std() / (vols_t.mean() + 1e-8))
82
+
83
+
84
+ # ── Autograd: tangential projection + separation only ──
85
+ # NO gradient injection. CV is a loss term, not gradient surgery.
86
+
87
+ class GeometricAutograd(torch.autograd.Function):
88
+ """
89
+ Gradient filtering only. Two operations:
90
+ 1. Tangential projection (keep gradients on hypersphere surface)
91
+ 2. Separation preservation (attenuate collapse toward nearest anchor)
92
+
93
+ CV regulation is handled by cv_loss in the training loop.
94
+ Not here. Loss terms flow gradient naturally. Surgery doesn't.
95
+ """
96
+
97
+ @staticmethod
98
+ def forward(ctx, x, embedding, anchors, tang_only, sep_strength):
99
+ ctx.save_for_backward(embedding, anchors)
100
+ ctx.tang_only = tang_only
101
+ ctx.sep_strength = sep_strength
102
+ return x
103
+
104
+ @staticmethod
105
+ def backward(ctx, grad_output):
106
+ embedding, anchors = ctx.saved_tensors
107
+ tang_only = ctx.tang_only
108
+ sep_strength = ctx.sep_strength
109
+
110
+ emb_n = F.normalize(embedding.detach().float(), dim=-1)
111
+ anchors_n = F.normalize(anchors.detach().float(), dim=-1)
112
+ grad_f = grad_output.float()
113
+
114
+ # 1. Tangential projection
115
+ tang, norm = tangential_projection(grad_f, emb_n)
116
+ corrected = tang + (1.0 - tang_only) * norm
117
+
118
+ # 2. Separation preservation
119
+ if sep_strength > 0:
120
+ cos_to_anchors = emb_n @ anchors_n.T
121
+ nearest_idx = cos_to_anchors.argmax(dim=-1)
122
+ nearest_anchor = anchors_n[nearest_idx]
123
+ toward_nearest = (corrected * nearest_anchor).sum(dim=-1, keepdim=True)
124
+ collapse_component = toward_nearest * nearest_anchor
125
+ is_collapsing = (toward_nearest > 0).float()
126
+ corrected = corrected - sep_strength * is_collapsing * collapse_component
127
+
128
+ return corrected.to(grad_output.dtype), None, None, None, None
129
+
130
+
131
+ # ── Anchor gradient filtering ──
132
+
133
+ class AnchorAutograd(torch.autograd.Function):
134
+ """Anchor gradients projected tangential per-anchor. No radial drift."""
135
+ @staticmethod
136
+ def forward(ctx, anchors, drift):
137
+ ctx.save_for_backward(anchors)
138
+ ctx.drift = drift
139
+ return anchors
140
+
141
+ @staticmethod
142
+ def backward(ctx, grad_output):
143
+ anchors, = ctx.saved_tensors
144
+ a_n = F.normalize(anchors.detach().float(), dim=-1)
145
+ grad_f = grad_output.float()
146
+ N = a_n.shape[0]
147
+ corrected = torch.zeros_like(grad_f)
148
+ for i in range(N):
149
+ g = grad_f[i]; a = a_n[i]
150
+ corrected[i] = (g - (g * a).sum() * a) * ctx.drift
151
+ return corrected.to(grad_output.dtype), None
152
+
153
+
154
+ # ── Additional forward losses (from bank research) ──
155
+
156
+ def anchor_spread_loss(anchors):
157
+ """Prevent anchor collapse. Off-diagonal cosineΒ² β†’ 0."""
158
+ a_n = F.normalize(anchors, dim=-1)
159
+ sim = a_n @ a_n.T
160
+ sim = sim - torch.diag(torch.diag(sim))
161
+ return sim.pow(2).mean()
162
+
163
+
164
+ def anchor_entropy_loss(emb, anchors, sharpness=10.0):
165
+ """Anchor assignment sharpness. Lower entropy = crisper triangulation."""
166
+ a_n = F.normalize(anchors, dim=-1)
167
+ probs = F.softmax(emb @ a_n.T * sharpness, dim=-1)
168
+ return -(probs * (probs + 1e-12).log()).sum(-1).mean()
169
+
170
+
171
+ def anchor_ortho_loss(anchors):
172
+ """Constellation orthogonality. Off-diagonal gram β†’ 0."""
173
+ a_n = F.normalize(anchors, dim=-1)
174
+ gram = a_n @ a_n.T
175
+ N = anchors.shape[0]
176
+ mask = ~torch.eye(N, dtype=bool, device=anchors.device)
177
+ return gram[mask].pow(2).mean()
178
+
179
+
180
+ def cluster_variance_loss(emb, anchors):
181
+ """Maximize cross-anchor differentiation. -var(per-anchor mean cos)."""
182
+ a_n = F.normalize(anchors, dim=-1)
183
+ per_anchor_mean = (emb @ a_n.T).mean(dim=0)
184
+ return -per_anchor_mean.var()
185
+
186
+
187
+ # ══════════════════════════════════════════════════════════════════
188
+ # CONSTELLATION (pure Xavier, no semantics)
189
+ # ══════════════════════════════════════════════════════════════════
190
+
191
+ class Constellation(nn.Module):
192
+ def __init__(self, n_anchors=30, d_embed=768):
193
+ super().__init__()
194
+ self.n_anchors = n_anchors
195
+ anchors = F.normalize(torch.randn(n_anchors, d_embed), dim=-1)
196
+ self.anchors = nn.Parameter(anchors)
197
+
198
+ self.register_buffer("rigidity", torch.zeros(n_anchors))
199
+ self.register_buffer("visit_count", torch.zeros(n_anchors))
200
+
201
+ def triangulate(self, emb):
202
+ anchors_n = F.normalize(self.anchors, dim=-1)
203
+ cos_sim = emb @ anchors_n.T # (B, N)
204
+ tri_dist = 1.0 - cos_sim # (B, N)
205
+ nearest = cos_sim.argmax(dim=-1) # (B,)
206
+ return tri_dist, nearest
207
+
208
+ @torch.no_grad()
209
+ def update_rigidity(self, tri_dist):
210
+ """
211
+ Rigidity by nearest-anchor assignment, NOT by class label.
212
+ Anchors are geometric reference points, not class proxies.
213
+ """
214
+ nearest = tri_dist.argmin(dim=-1) # (B,) β€” nearest anchor per sample
215
+ for i in range(self.n_anchors):
216
+ mask = nearest == i
217
+ if mask.sum() < 5: continue
218
+ self.visit_count[i] += mask.sum().float()
219
+ cluster_dists = tri_dist[mask]
220
+ spread = cluster_dists.std(dim=0).mean()
221
+ alpha = min(0.1, 10.0 / (self.visit_count[i] + 1))
222
+ old = self.rigidity[i]
223
+ self.rigidity[i] = (1 - alpha) * old + alpha * (1.0 / (spread + 0.01))
224
+
225
+ def health(self):
226
+ a = F.normalize(self.anchors.detach(), dim=-1)
227
+ cos = a @ a.T
228
+ mask = ~torch.eye(self.n_anchors, dtype=bool, device=a.device)
229
+ return {
230
+ "mean_cos": cos[mask].mean().item(),
231
+ "std_cos": cos[mask].std().item(),
232
+ "min_gap": (1 - cos[mask].max()).item(),
233
+ "max_gap": (1 - cos[mask].min()).item(),
234
+ }
235
+
236
+
237
+ # ══════════════════════════════════════════════════════════════════
238
+ # PATCHWORK: compartmentalized anchor groups β†’ MLPs β†’ concat
239
+ # ══════════════════════════════════════════════════════════════════
240
+
241
+ class Patchwork(nn.Module):
242
+ """
243
+ Partition N anchors into K compartments.
244
+ Each compartment has its own MLP processing the triangulation
245
+ distances for its anchors.
246
+
247
+ Compartment assignments are fixed at init (evenly split).
248
+ Each compartment MLP: (B, anchors_per_compartment) β†’ (B, d_comp)
249
+ All compartments concatenate β†’ (B, K * d_comp)
250
+ """
251
+
252
+ def __init__(self, n_anchors=30, n_compartments=6, d_comp=64):
253
+ super().__init__()
254
+ self.n_anchors = n_anchors
255
+ self.n_compartments = n_compartments
256
+ self.d_comp = d_comp
257
+
258
+ # Assign anchors to compartments (evenly)
259
+ assignments = torch.arange(n_anchors) % n_compartments
260
+ self.register_buffer("assignments", assignments)
261
+
262
+ # Per-compartment MLP
263
+ anchors_per = n_anchors // n_compartments
264
+ remainder = n_anchors % n_compartments
265
+
266
+ self.compartments = nn.ModuleList()
267
+ for k in range(n_compartments):
268
+ n_k = (assignments == k).sum().item()
269
+ self.compartments.append(nn.Sequential(
270
+ nn.Linear(n_k, d_comp * 2),
271
+ nn.GELU(),
272
+ nn.Linear(d_comp * 2, d_comp),
273
+ nn.LayerNorm(d_comp),
274
+ ))
275
+
276
+ def forward(self, tri_dist):
277
+ """
278
+ Args:
279
+ tri_dist: (B, N) triangulation distances to all anchors
280
+
281
+ Returns:
282
+ features: (B, K * d_comp)
283
+ """
284
+ parts = []
285
+ for k in range(self.n_compartments):
286
+ mask = self.assignments == k
287
+ comp_input = tri_dist[:, mask] # (B, n_k)
288
+ parts.append(self.compartments[k](comp_input)) # (B, d_comp)
289
+ return torch.cat(parts, dim=-1) # (B, K * d_comp)
290
+
291
+
292
+ # ══════════════════════════════════════════════════════════════════
293
+ # FULL MODEL
294
+ # ══════════════════════════════════════════════════════════════════
295
+
296
+ class PatchworkClassifier(nn.Module):
297
+ def __init__(self, n_classes=30, n_anchors=30, d_embed=768,
298
+ n_compartments=6, d_comp=64, d_hidden=256):
299
+ super().__init__()
300
+
301
+ # Image backbone
302
+ self.backbone = nn.Sequential(
303
+ nn.Conv2d(1, 32, 3, padding=1), nn.GELU(), nn.MaxPool2d(2),
304
+ nn.Conv2d(32, 64, 3, padding=1), nn.GELU(), nn.MaxPool2d(2),
305
+ nn.Conv2d(64, 128, 3, padding=1), nn.GELU(), nn.AdaptiveAvgPool2d(1),
306
+ )
307
+ self.embed_proj = nn.Sequential(
308
+ nn.Linear(128, d_embed), nn.LayerNorm(d_embed),
309
+ )
310
+
311
+ # Constellation
312
+ self.constellation = Constellation(n_anchors, d_embed)
313
+
314
+ # Patchwork
315
+ self.patchwork = Patchwork(n_anchors, n_compartments, d_comp)
316
+
317
+ # Funnel MLP
318
+ pw_dim = n_compartments * d_comp
319
+ self.mlp = nn.Sequential(
320
+ nn.Linear(pw_dim, d_hidden), nn.GELU(), nn.LayerNorm(d_hidden),
321
+ nn.Linear(d_hidden, d_hidden), nn.GELU(), nn.LayerNorm(d_hidden),
322
+ nn.Linear(d_hidden, n_classes),
323
+ )
324
+
325
+ def forward(self, x):
326
+ feat = self.backbone(x).flatten(1)
327
+ emb = F.normalize(self.embed_proj(feat), dim=-1)
328
+ tri_dist, nearest = self.constellation.triangulate(emb)
329
+ pw_feat = self.patchwork(tri_dist)
330
+ logits = self.mlp(pw_feat)
331
+ return logits, emb, tri_dist, nearest
332
+
333
+
334
+ # ══════════════════════════════════════════════════════════════════
335
+ # SHAPE RENDERERS (compact)
336
+ # ══════════════════════════════════════════════════════════════════
337
+
338
+ def _d(img, x0, y0, x1, y1, t=1):
339
+ n=max(int(max(abs(x1-x0),abs(y1-y0))*2),1); sz=img.shape[0]
340
+ for s in np.linspace(0,1,n):
341
+ px,py=int(x0+s*(x1-x0)),int(y0+s*(y1-y0))
342
+ for dx in range(-t,t+1):
343
+ for dy in range(-t,t+1):
344
+ nx,ny=px+dx,py+dy
345
+ if 0<=nx<sz and 0<=ny<sz: img[ny,nx]=1.0
346
+
347
+ def rpoly(nv,sz=32,p=0.15):
348
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2,sz/2,sz*0.35
349
+ a=np.linspace(0,2*np.pi,nv,endpoint=False)+np.random.uniform(0,2*np.pi)
350
+ ri=r*(1+np.random.normal(0,p,nv))
351
+ pts=[(cx+ri[i]*np.cos(a[i]),cy+ri[i]*np.sin(a[i])) for i in range(nv)]
352
+ for i in range(nv): _d(img,*pts[i],*pts[(i+1)%nv])
353
+ return img
354
+
355
+ def rstar(np_,sz=32,p=0.12):
356
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2;ro,ri_=sz*0.38,sz*0.15
357
+ a=np.linspace(0,2*np.pi,np_*2,endpoint=False)+np.random.uniform(0,2*np.pi)
358
+ pts=[(cx+(ro if i%2==0 else ri_)*(1+np.random.normal(0,p))*np.cos(a[i]),
359
+ cy+(ro if i%2==0 else ri_)*(1+np.random.normal(0,p))*np.sin(a[i])) for i in range(len(a))]
360
+ for i in range(len(pts)): _d(img,*pts[i],*pts[(i+1)%len(pts)])
361
+ return img
362
+
363
+ def rcross(sz=32,p=0.15):
364
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,arm=sz/2,sz/2,sz*0.3
365
+ for ab in [0,np.pi/2,np.pi,3*np.pi/2]:
366
+ a=ab+np.random.normal(0,p*0.3);r=arm*(1+np.random.normal(0,p))
367
+ _d(img,cx,cy,cx+r*np.cos(a),cy+r*np.sin(a),2)
368
+ return img
369
+
370
+ def rspiral(sz=32,p=0.1):
371
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2
372
+ for t in np.linspace(0,5*np.pi,200):
373
+ r=sz*0.015*t*(1+np.random.normal(0,p*0.3));x,y=int(cx+r*np.cos(t)),int(cy+r*np.sin(t))
374
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
375
+ return img
376
+
377
+ def rwave(sz=32,p=0.1):
378
+ img=np.zeros((sz,sz),dtype=np.float32);f=2+np.random.normal(0,0.3);amp=sz*0.15*(1+np.random.normal(0,p))
379
+ for x in range(sz):
380
+ y=int(sz/2+amp*np.sin(2*np.pi*f*x/sz))
381
+ if 0<=y<sz: img[y,x]=1.0
382
+ return img
383
+
384
+ def rheart(sz=32,p=0.1):
385
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz*0.45;s=sz*0.017*(1+np.random.normal(0,p))
386
+ for t in np.linspace(0,2*np.pi,300):
387
+ x=16*np.sin(t)**3;y=-(13*np.cos(t)-5*np.cos(2*t)-2*np.cos(3*t)-np.cos(4*t))
388
+ ix,iy=int(cx+x*s),int(cy+y*s)
389
+ if 0<=ix<sz and 0<=iy<sz: img[iy,ix]=1.0
390
+ return img
391
+
392
+ def rcrescent(sz=32,p=0.1):
393
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2,sz/2,sz*0.35;r2=r*0.7;off=r*0.3
394
+ for a in np.linspace(0,2*np.pi,300):
395
+ x1,y1=cx+r*np.cos(a),cy+r*np.sin(a)
396
+ if math.sqrt((x1-cx-off)**2+(y1-cy)**2)>=r2*0.9:
397
+ ix,iy=int(x1),int(y1)
398
+ if 0<=ix<sz and 0<=iy<sz: img[iy,ix]=1.0
399
+ return img
400
+
401
+ def rellipse(sz=32,p=0.1):
402
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2
403
+ a,b=sz*0.38*(1+np.random.normal(0,p)),sz*0.22*(1+np.random.normal(0,p));rot=np.random.uniform(0,np.pi)
404
+ for t in np.linspace(0,2*np.pi,200):
405
+ x,y=a*np.cos(t),b*np.sin(t);ix,iy=int(cx+x*np.cos(rot)-y*np.sin(rot)),int(cy+x*np.sin(rot)+y*np.cos(rot))
406
+ if 0<=ix<sz and 0<=iy<sz: img[iy,ix]=1.0
407
+ return img
408
+
409
+ def rring(sz=32,p=0.1):
410
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2
411
+ r1,r2=sz*0.35*(1+np.random.normal(0,p)),sz*0.22*(1+np.random.normal(0,p))
412
+ for a in np.linspace(0,2*np.pi,300):
413
+ for r in [r1,r2]:
414
+ x,y=int(cx+r*np.cos(a)),int(cy+r*np.sin(a))
415
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
416
+ return img
417
+
418
+ def rarrow(sz=32,p=0.12):
419
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2
420
+ l=sz*0.35*(1+np.random.normal(0,p));h=l*0.35;a=np.random.uniform(0,2*np.pi)
421
+ x1,y1=cx-l*np.cos(a),cy-l*np.sin(a);x2,y2=cx+l*np.cos(a),cy+l*np.sin(a)
422
+ _d(img,x1,y1,x2,y2)
423
+ for da in [0.7,-0.7]: _d(img,x2,y2,x2-h*np.cos(a+da),y2-h*np.sin(a+da))
424
+ return img
425
+
426
+ def rchevron(sz=32,p=0.12):
427
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2,sz/2
428
+ w,h=sz*0.3*(1+np.random.normal(0,p)),sz*0.25*(1+np.random.normal(0,p))
429
+ _d(img,cx-w,cy+h,cx,cy-h);_d(img,cx,cy-h,cx+w,cy+h)
430
+ return img
431
+
432
+ def rsemicirc(sz=32,p=0.1):
433
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2,sz*0.6,sz*0.35
434
+ for a in np.linspace(np.pi,2*np.pi,150):
435
+ x,y=int(cx+r*np.cos(a)),int(cy+r*np.sin(a))
436
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
437
+ _d(img,cx-r,cy,cx+r,cy)
438
+ return img
439
+
440
+ NAMES = ["triangle","square","pentagon","hexagon","heptagon","octagon","nonagon",
441
+ "decagon","dodecagon","circle","ellipse","spiral","wave","crescent",
442
+ "star3","star4","star5","star6","star7","star8","cross","diamond",
443
+ "arrow","heart","ring","semicircle","trapezoid","parallelogram","rhombus","chevron"]
444
+
445
+ def gen_one(c,sz=32):
446
+ if c==0: return rpoly(3,sz,0.20)
447
+ if c==1: return rpoly(4,sz,0.12)
448
+ if c==2: return rpoly(5,sz,0.15)
449
+ if c==3: return rpoly(6,sz,0.10)
450
+ if c==4: return rpoly(7,sz,0.10)
451
+ if c==5: return rpoly(8,sz,0.08)
452
+ if c==6: return rpoly(9,sz,0.08)
453
+ if c==7: return rpoly(10,sz,0.07)
454
+ if c==8: return rpoly(12,sz,0.06)
455
+ if c==9: return rpoly(32,sz,0.03)
456
+ if c==10: return rellipse(sz)
457
+ if c==11: return rspiral(sz)
458
+ if c==12: return rwave(sz)
459
+ if c==13: return rcrescent(sz)
460
+ if c==14: return rstar(3,sz)
461
+ if c==15: return rstar(4,sz)
462
+ if c==16: return rstar(5,sz)
463
+ if c==17: return rstar(6,sz)
464
+ if c==18: return rstar(7,sz)
465
+ if c==19: return rstar(8,sz)
466
+ if c==20: return rcross(sz)
467
+ if c==21: return rpoly(4,sz,0.10)
468
+ if c==22: return rarrow(sz)
469
+ if c==23: return rheart(sz)
470
+ if c==24: return rring(sz)
471
+ if c==25: return rsemicirc(sz)
472
+ if c==26: return rpoly(4,sz,0.15)
473
+ if c==27: return rpoly(4,sz,0.18)
474
+ if c==28: return rpoly(4,sz,0.10)
475
+ if c==29: return rchevron(sz)
476
+ return rpoly(3,sz)
477
+
478
+ def gen_data(n_per=500, sz=32):
479
+ imgs, labels = [], []
480
+ for _ in range(n_per):
481
+ for c in range(30):
482
+ imgs.append(gen_one(c, sz)); labels.append(c)
483
+ imgs = torch.tensor(np.array(imgs)).unsqueeze(1)
484
+ labels = torch.tensor(labels, dtype=torch.long)
485
+ perm = torch.randperm(len(labels))
486
+ return imgs[perm], labels[perm]
487
+
488
+
489
+ # ══════════════════════════════════════════════════════════════════
490
+ # SINGLE TRAINING RUN
491
+ # ═══════════════════���══════════════════════════════════════════════
492
+
493
+ def train_once(tang_only=0.01, cv_weight=0.001, sep_strength=1.0,
494
+ anchor_drift=0.0, w_spread=0.0, w_entropy=0.0,
495
+ w_ortho=0.0, w_cluster=0.0,
496
+ use_autograd=True, epochs=30, seed=42, verbose=True):
497
+ """
498
+ Proven base: tang=0.01, sep=1.0, cv=0.001
499
+ New losses start at zero, layered in individually.
500
+ Adam, NOT AdamW. Geometry IS the regularization.
501
+ """
502
+ torch.manual_seed(seed); np.random.seed(seed)
503
+
504
+ train_imgs, train_labels = gen_data(n_per=500)
505
+ val_imgs, val_labels = gen_data(n_per=100)
506
+ train_imgs, train_labels = train_imgs.to(DEVICE), train_labels.to(DEVICE)
507
+ val_imgs, val_labels = val_imgs.to(DEVICE), val_labels.to(DEVICE)
508
+ n_train, n_val = len(train_labels), len(val_labels)
509
+
510
+ model = PatchworkClassifier(
511
+ n_classes=30, n_anchors=64, d_embed=768,
512
+ n_compartments=6, d_comp=64, d_hidden=256,
513
+ ).to(DEVICE)
514
+
515
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
516
+ BATCH = 256
517
+
518
+ history = []
519
+
520
+ for epoch in range(epochs):
521
+ model.train()
522
+ perm = torch.randperm(n_train, device=DEVICE)
523
+ total_loss, total_correct, n = 0, 0, 0
524
+
525
+ for i in range(0, n_train, BATCH):
526
+ idx = perm[i:i+BATCH]
527
+ if len(idx) < 4: continue
528
+
529
+ logits, emb, tri, nearest = model(train_imgs[idx])
530
+ labels = train_labels[idx]
531
+ anchors = model.constellation.anchors
532
+
533
+ if use_autograd and (tang_only > 0 or sep_strength > 0):
534
+ emb_corrected = GeometricAutograd.apply(
535
+ emb, emb, anchors, tang_only, sep_strength)
536
+ tri_g, _ = model.constellation.triangulate(emb_corrected)
537
+ pw_feat = model.patchwork(tri_g)
538
+ logits = model.mlp(pw_feat)
539
+
540
+ if use_autograd and anchor_drift > 0:
541
+ _ = AnchorAutograd.apply(anchors, anchor_drift)
542
+
543
+ # Task loss
544
+ l_cls = F.cross_entropy(logits, labels)
545
+
546
+ # Geometric losses (all differentiable, proven micro weights)
547
+ l_geo = torch.tensor(0.0, device=DEVICE)
548
+ if cv_weight > 0:
549
+ l_geo = l_geo + cv_weight * cv_loss(emb, target=0.2, n_samples=16)
550
+ if w_spread > 0:
551
+ l_geo = l_geo + w_spread * anchor_spread_loss(anchors)
552
+ if w_entropy > 0:
553
+ l_geo = l_geo + w_entropy * anchor_entropy_loss(emb, anchors)
554
+ if w_ortho > 0:
555
+ l_geo = l_geo + w_ortho * anchor_ortho_loss(anchors)
556
+ if w_cluster > 0:
557
+ l_geo = l_geo + w_cluster * cluster_variance_loss(emb, anchors)
558
+
559
+ loss = l_cls + l_geo
560
+ loss.backward()
561
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
562
+ optimizer.step(); optimizer.zero_grad(set_to_none=True)
563
+
564
+ model.constellation.update_rigidity(tri.detach())
565
+
566
+ total_correct += (logits.argmax(-1) == labels).sum().item()
567
+ total_loss += loss.item()
568
+ n += 1
569
+
570
+ train_acc = total_correct / n_train
571
+
572
+ # Val
573
+ model.eval()
574
+ with torch.no_grad():
575
+ vl, ve, vt, vn = model(val_imgs)
576
+ v_acc = (vl.argmax(-1) == val_labels).float().mean().item()
577
+ v_cv = cv_metric(ve, n_samples=100)
578
+
579
+ # Anchor health
580
+ health = model.constellation.health()
581
+
582
+ # Measure equidistance quality
583
+ a_n = F.normalize(model.constellation.anchors, dim=-1)
584
+ cos_mat = a_n @ a_n.T
585
+ mask = ~torch.eye(a_n.shape[0], dtype=bool, device=DEVICE)
586
+ equi_std = cos_mat[mask].std().item()
587
+
588
+ types = {"polygon": list(range(9)), "curve": list(range(9,14)),
589
+ "star": list(range(14,20)), "structure": list(range(20,30))}
590
+ ta = {}
591
+ for tname, tids in types.items():
592
+ tmask = torch.zeros(n_val, dtype=bool, device=DEVICE)
593
+ for tid in tids: tmask |= (val_labels == tid)
594
+ if tmask.sum() > 0:
595
+ ta[tname] = (vl.argmax(-1)[tmask] == val_labels[tmask]).float().mean().item()
596
+
597
+ history.append({
598
+ "epoch": epoch + 1, "train_acc": train_acc, "val_acc": v_acc,
599
+ "val_cv": v_cv, "equi_std": equi_std, "type_accs": ta,
600
+ })
601
+
602
+ if verbose and ((epoch + 1) % 10 == 0 or epoch == 0):
603
+ ta_str = " ".join(f"{t}={a:.2f}" for t, a in ta.items())
604
+ rig = model.constellation.rigidity
605
+ cv_delta = v_cv - 0.2
606
+ print(f" E{epoch+1:2d}: t={train_acc:.3f} v={v_acc:.3f} "
607
+ f"cv={v_cv:.4f}(Ξ”{cv_delta:+.3f}) equi={equi_std:.4f} "
608
+ f"rig={rig.mean():.1f}/{rig.max():.1f} [{ta_str}]")
609
+
610
+ health = model.constellation.health()
611
+ return history, health, model
612
+
613
+
614
+ # ═══════���══════════════════════════════════════════════════════════
615
+ # GATE SWEEP
616
+ # ══════════════════════════════════════════════════════════════════
617
+
618
+ print(f"\n{'='*65}")
619
+ print("GATE SWEEP: Varying gate parameters")
620
+ print(f"{'='*65}")
621
+ print(f" Device: {DEVICE}")
622
+ print(f" 30 classes, 15K train, 3K val")
623
+
624
+ configs = [
625
+ # (name, tang, cv_w, sep, drift, spread, entropy, ortho, cluster, use_ag)
626
+ # Proven base
627
+ ("raw_adam", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, False),
628
+ ("proven", 0.01, 0.001, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, True),
629
+ # + one new loss each
630
+ ("+spread", 0.01, 0.001, 1.0, 0.0, 1e-3, 0.0, 0.0, 0.0, True),
631
+ ("+entropy", 0.01, 0.001, 1.0, 0.0, 0.0, 1e-4, 0.0, 0.0, True),
632
+ ("+ortho", 0.01, 0.001, 1.0, 0.0, 0.0, 0.0, 1e-3, 0.0, True),
633
+ ("+cluster", 0.01, 0.001, 1.0, 0.0, 0.0, 0.0, 0.0, 1e-4, True),
634
+ ("+drift", 0.01, 0.001, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, True),
635
+ # best combos
636
+ ("+spr+ort", 0.01, 0.001, 1.0, 0.0, 1e-3, 0.0, 1e-3, 0.0, True),
637
+ ("+all_micro", 0.01, 0.001, 1.0, 0.5, 1e-3, 1e-4, 1e-3, 1e-4, True),
638
+ ]
639
+
640
+ results = {}
641
+ for name, to, cw, sp, dr, ws, we, wo, wc, ua in configs:
642
+ print(f"\n ── {name} ──")
643
+ hist, health, _ = train_once(
644
+ tang_only=to, cv_weight=cw, sep_strength=sp,
645
+ anchor_drift=dr, w_spread=ws, w_entropy=we,
646
+ w_ortho=wo, w_cluster=wc,
647
+ use_autograd=ua, epochs=30, verbose=True)
648
+ final = hist[-1]
649
+ results[name] = {
650
+ "val_acc": final["val_acc"],
651
+ "train_acc": final["train_acc"],
652
+ "gap": final["train_acc"] - final["val_acc"],
653
+ "val_cv": final["val_cv"],
654
+ "equi_std": final["equi_std"],
655
+ "health": health,
656
+ "type_accs": final["type_accs"],
657
+ "cv_std": np.std([h["val_cv"] for h in hist]),
658
+ }
659
+
660
+
661
+ # ══════════════════════════════════════════════════════════════════
662
+ # SUMMARY
663
+ # ══════════════════════════════════════════════════════════════════
664
+
665
+ print(f"\n\n{'='*65}")
666
+ print("SWEEP RESULTS")
667
+ print(f"{'='*65}")
668
+
669
+ print(f"\n {'Config':<15} {'v_acc':>6} {'t_acc':>6} {'gap':>6} "
670
+ f"{'cv':>7} {'Ξ”cv':>7} {'eq_std':>7} {'poly':>5} {'curve':>5} {'star':>5} {'struct':>5}")
671
+ print(f" {'-'*90}")
672
+
673
+ for name in [c[0] for c in configs]:
674
+ r = results[name]
675
+ ta = r["type_accs"]
676
+ cv_delta = r["val_cv"] - 0.2
677
+ print(f" {name:<15} {r['val_acc']:>6.3f} {r['train_acc']:>6.3f} {r['gap']:>+6.3f} "
678
+ f"{r['val_cv']:>7.4f} {cv_delta:>+7.4f} {r['equi_std']:>7.4f} "
679
+ f"{ta.get('polygon',0):>5.2f} {ta.get('curve',0):>5.2f} "
680
+ f"{ta.get('star',0):>5.2f} {ta.get('structure',0):>5.2f}")
681
+
682
+ # Find best overall
683
+ best = max(results.items(), key=lambda x: x[1]["val_acc"])
684
+ print(f"\n Best accuracy: {best[0]} (val_acc={best[1]['val_acc']:.3f})")
685
+
686
+ # Find best structure accuracy (hardest category)
687
+ best_struct = max(results.items(), key=lambda x: x[1]["type_accs"].get("structure", 0))
688
+ print(f" Best structure: {best_struct[0]} (struct={best_struct[1]['type_accs'].get('structure',0):.3f})")
689
+
690
+ # Find closest to CV target 0.2
691
+ closest_cv = min(results.items(), key=lambda x: abs(x[1]["val_cv"] - 0.2))
692
+ print(f" Closest to CV=0.2: {closest_cv[0]} (cv={closest_cv[1]['val_cv']:.4f}, Ξ”={closest_cv[1]['val_cv']-0.2:+.4f})")
693
+
694
+ # Find most equidistant constellation
695
+ best_equi = min(results.items(), key=lambda x: x[1]["equi_std"])
696
+ print(f" Most equidistant: {best_equi[0]} (equi_std={best_equi[1]['equi_std']:.4f})")
697
+
698
+ # Find most stable CV trajectory
699
+ best_cv = min(results.items(), key=lambda x: x[1]["cv_std"])
700
+ print(f" Most stable CV: {best_cv[0]} (cv_std={best_cv[1]['cv_std']:.4f})")
701
+
702
+ print(f"\n{'='*65}")
703
+ print("DONE")
704
+ print(f"{'='*65}")