AbstractPhil commited on
Commit
aaf9465
Β·
verified Β·
1 Parent(s): 08d1095

Create multigenerational_trainer.py

Browse files
Files changed (1) hide show
  1. multigenerational_trainer.py +826 -0
multigenerational_trainer.py ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # DATA-DIVERSE GEOMETRIC EVOLUTION
3
+ #
4
+ # Each generation trains on differently-perturbed data.
5
+ # Consensus captures what's INVARIANT across perturbations.
6
+ #
7
+ # Gen 0: 2 founders, Dataset A (standard)
8
+ # β†’ GPA β†’ consensus anchors
9
+ #
10
+ # Gen 1: 2 students distilled from Gen 0 consensus
11
+ # Student S1: Dataset B (high noise, thick strokes)
12
+ # Student S2: Dataset C (thin strokes, shifted centers)
13
+ # β†’ GPA consensus of S1 + S2
14
+ #
15
+ # Gen 2: 3 offspring from Gen 1 consensus + 1 new founder on Dataset D
16
+ # β†’ GPA consensus of 4
17
+ #
18
+ # Gen 3: 5 models, each on Dataset E (identical perturbation style,
19
+ # different random samples)
20
+ # β†’ GPA consensus of 5
21
+ #
22
+ # Gen 4 (FINAL): 3 triplets, each selecting different 5 parents
23
+ # from the ENTIRE lineage pool
24
+ # ============================================================================
25
+
26
+ import math
27
+ import gc
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ print("=" * 65)
36
+ print("DATA-DIVERSE GEOMETRIC EVOLUTION")
37
+ print("=" * 65)
38
+ print(f" Device: {DEVICE}")
39
+
40
+
41
+ # ══════════════════════════════════════════════════════════════════
42
+ # GEOMETRIC PRIMITIVES
43
+ # ══════════════════════════════════════════════════════════════════
44
+
45
+ def tangential_projection(grad, embedding):
46
+ emb_n = F.normalize(embedding.detach().float(), dim=-1)
47
+ grad_f = grad.float()
48
+ radial = (grad_f * emb_n).sum(dim=-1, keepdim=True) * emb_n
49
+ return (grad_f - radial).to(grad.dtype), radial.to(grad.dtype)
50
+
51
+ def cayley_menger_vol2(pts):
52
+ pts = pts.float()
53
+ diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
54
+ d2 = (diff * diff).sum(-1)
55
+ B, V, _ = d2.shape
56
+ cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
57
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
58
+ s = (-1.0)**V; f = math.factorial(V-1)
59
+ return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
60
+
61
+ def cv_loss(emb, target=0.2, n_samples=16):
62
+ B = emb.shape[0]
63
+ if B < 5: return torch.tensor(0.0, device=emb.device)
64
+ vols = []
65
+ for _ in range(n_samples):
66
+ idx = torch.randperm(B, device=emb.device)[:5]
67
+ v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
68
+ vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
69
+ stacked = torch.stack(vols)
70
+ cv = stacked.std() / (stacked.mean() + 1e-8)
71
+ return (cv - target).abs()
72
+
73
+ @torch.no_grad()
74
+ def cv_metric(emb, n_samples=200):
75
+ B = emb.shape[0]
76
+ if B < 5: return 0.0
77
+ emb_f = emb.detach().float()
78
+ vols = []
79
+ for _ in range(n_samples):
80
+ idx = torch.randperm(B, device=emb.device)[:5]
81
+ v2 = cayley_menger_vol2(emb_f[idx].unsqueeze(0))
82
+ v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
83
+ if v > 0: vols.append(v)
84
+ if len(vols) < 10: return 0.0
85
+ a = torch.tensor(vols)
86
+ return float(a.std() / (a.mean() + 1e-8))
87
+
88
+ def anchor_spread_loss(anchors):
89
+ a_n = F.normalize(anchors, dim=-1)
90
+ sim = a_n @ a_n.T - torch.diag(torch.ones(anchors.shape[0], device=anchors.device))
91
+ return sim.pow(2).mean()
92
+
93
+ def anchor_entropy_loss(emb, anchors, sharpness=10.0):
94
+ a_n = F.normalize(anchors, dim=-1)
95
+ probs = F.softmax(emb @ a_n.T * sharpness, dim=-1)
96
+ return -(probs * (probs + 1e-12).log()).sum(-1).mean()
97
+
98
+ def anchor_ortho_loss(anchors):
99
+ a_n = F.normalize(anchors, dim=-1)
100
+ gram = a_n @ a_n.T
101
+ N = anchors.shape[0]
102
+ mask = ~torch.eye(N, dtype=bool, device=anchors.device)
103
+ return gram[mask].pow(2).mean()
104
+
105
+ def infonce(a, b, temperature=0.07):
106
+ a = F.normalize(a, dim=-1); b = F.normalize(b, dim=-1)
107
+ logits = (a @ b.T) / temperature
108
+ labels = torch.arange(logits.shape[0], device=logits.device)
109
+ return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
110
+
111
+ class EmbeddingAutograd(torch.autograd.Function):
112
+ @staticmethod
113
+ def forward(ctx, x, embedding, anchors, tang, sep):
114
+ ctx.save_for_backward(embedding, anchors)
115
+ ctx.tang = tang; ctx.sep = sep
116
+ return x
117
+ @staticmethod
118
+ def backward(ctx, grad_output):
119
+ embedding, anchors = ctx.saved_tensors
120
+ emb_n = F.normalize(embedding.detach().float(), dim=-1)
121
+ anchors_n = F.normalize(anchors.detach().float(), dim=-1)
122
+ grad_f = grad_output.float()
123
+ tang_grad, norm_grad = tangential_projection(grad_f, emb_n)
124
+ corrected = tang_grad + (1.0 - ctx.tang) * norm_grad
125
+ if ctx.sep > 0:
126
+ cos_to = emb_n @ anchors_n.T
127
+ nearest = anchors_n[cos_to.argmax(dim=-1)]
128
+ toward = (corrected * nearest).sum(dim=-1, keepdim=True)
129
+ collapse = toward * nearest
130
+ corrected = corrected - ctx.sep * (toward > 0).float() * collapse
131
+ return corrected.to(grad_output.dtype), None, None, None, None
132
+
133
+
134
+ # ══════════════════════════════════════════════════════════════════
135
+ # PROCRUSTES
136
+ # ══════════════════════════════════════════════════════════════════
137
+
138
+ def symmetric_inv_sqrt(cov, eps=1e-6):
139
+ evals, evecs = torch.linalg.eigh(cov)
140
+ return evecs @ torch.diag(torch.clamp(evals, min=eps).rsqrt()) @ evecs.T
141
+
142
+ def procrustes_align(source, target, n_align=10000):
143
+ N = min(n_align, source.shape[0], target.shape[0])
144
+ S = source[:N].float(); T = target[:N].float()
145
+ s_mean = S.mean(0, keepdim=True); Sc = S - s_mean; Ns = Sc.shape[0]
146
+ s_cov = (Sc.T @ Sc) / max(Ns-1, 1)
147
+ t_mean = T.mean(0, keepdim=True); Tc = T - t_mean
148
+ t_cov = (Tc.T @ Tc) / max(Ns-1, 1)
149
+ s_w = symmetric_inv_sqrt(s_cov); t_w = symmetric_inv_sqrt(t_cov)
150
+ Sc_w = F.normalize(Sc @ s_w, dim=-1); Tc_w = F.normalize(Tc @ t_w, dim=-1)
151
+ U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
152
+ return {"rotation": U @ Vt, "source_mean": s_mean.squeeze(0), "source_whitener": s_w}
153
+
154
+ def apply_align(emb, info):
155
+ return (emb.float() - info["source_mean"]) @ info["source_whitener"] @ info["rotation"].T
156
+
157
+ def gpa_consensus(embeddings_list, n_iters=15):
158
+ N = len(embeddings_list)
159
+ cur = {i: e.float() for i, e in enumerate(embeddings_list)}
160
+ for it in range(n_iters):
161
+ mean = sum(cur[i] for i in range(N)) / N
162
+ delta = 0.0
163
+ new_cur = {}
164
+ for i in range(N):
165
+ info = procrustes_align(cur[i], mean)
166
+ new_cur[i] = apply_align(cur[i], info)
167
+ delta += (new_cur[i] - cur[i]).pow(2).mean().item()
168
+ cur = new_cur
169
+ if delta < 1e-8: break
170
+ mean = sum(cur[i] for i in range(N)) / N
171
+ return F.normalize(mean, dim=-1)
172
+
173
+ def consensus_anchors(consensus, n_anchors=1024):
174
+ """
175
+ K-means on consensus embeddings. Anchors discover their own
176
+ regions of the manifold independent of class boundaries.
177
+ """
178
+ emb = consensus.detach().float()
179
+ N, D = emb.shape
180
+
181
+ # Init: random subset
182
+ idx = torch.randperm(N)[:n_anchors]
183
+ centers = emb[idx].clone()
184
+
185
+ for _ in range(30):
186
+ # Assign
187
+ cos = emb @ F.normalize(centers, dim=-1).T
188
+ assignments = cos.argmax(dim=-1)
189
+ # Update
190
+ new_centers = torch.zeros_like(centers)
191
+ for k in range(n_anchors):
192
+ mask = assignments == k
193
+ if mask.sum() > 0:
194
+ new_centers[k] = emb[mask].mean(0)
195
+ else:
196
+ new_centers[k] = emb[torch.randint(N, (1,))].squeeze(0)
197
+ delta = (F.normalize(new_centers, dim=-1) - F.normalize(centers, dim=-1)).pow(2).sum()
198
+ centers = new_centers
199
+ if delta < 1e-6: break
200
+
201
+ return F.normalize(centers, dim=-1)
202
+
203
+
204
+ # ══════════════════════════════════════════════════════════════════
205
+ # MODEL
206
+ # ══════════════════════════════════════════════════════════════════
207
+
208
+ class Constellation(nn.Module):
209
+ def __init__(self, n_anchors=1024, d_embed=64, init_anchors=None):
210
+ super().__init__()
211
+ self.n_anchors = n_anchors
212
+ if init_anchors is not None:
213
+ self.anchors = nn.Parameter(init_anchors.clone())
214
+ else:
215
+ self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
216
+ self.register_buffer("rigidity", torch.zeros(n_anchors))
217
+ self.register_buffer("visit_count", torch.zeros(n_anchors))
218
+ def triangulate(self, emb):
219
+ a = F.normalize(self.anchors, dim=-1)
220
+ cos = emb @ a.T
221
+ return 1.0 - cos, cos.argmax(dim=-1)
222
+ @torch.no_grad()
223
+ def update_rigidity(self, tri):
224
+ nearest = tri.argmin(dim=-1)
225
+ for i in range(self.n_anchors):
226
+ m = nearest == i
227
+ if m.sum() < 5: continue
228
+ self.visit_count[i] += m.sum().float()
229
+ sp = tri[m].std(dim=0).mean()
230
+ alpha = min(0.1, 10.0 / (self.visit_count[i] + 1))
231
+ self.rigidity[i] = (1-alpha)*self.rigidity[i] + alpha/(sp+0.01)
232
+
233
+ class Patchwork(nn.Module):
234
+ def __init__(self, n_anchors=1024, n_comp=6, d_comp=64):
235
+ super().__init__()
236
+ self.n_comp = n_comp
237
+ asgn = torch.arange(n_anchors) % n_comp
238
+ self.register_buffer("asgn", asgn)
239
+ self.comps = nn.ModuleList([nn.Sequential(
240
+ nn.Linear((asgn==k).sum().item(), d_comp*2), nn.GELU(),
241
+ nn.Linear(d_comp*2, d_comp), nn.LayerNorm(d_comp)) for k in range(n_comp)])
242
+ def forward(self, tri):
243
+ return torch.cat([self.comps[k](tri[:, self.asgn==k]) for k in range(self.n_comp)], -1)
244
+
245
+ class PatchworkClassifier(nn.Module):
246
+ def __init__(self, nc=30, na=1024, de=256, ncomp=6, dc=64, dh=256, init_a=None):
247
+ super().__init__()
248
+ if init_a is not None:
249
+ na = init_a.shape[0] # infer from provided anchors
250
+ self.backbone = nn.Sequential(
251
+ nn.Conv2d(1,32,3,padding=1), nn.GELU(), nn.MaxPool2d(2),
252
+ nn.Conv2d(32,64,3,padding=1), nn.GELU(), nn.MaxPool2d(2),
253
+ nn.Conv2d(64,128,3,padding=1), nn.GELU(), nn.AdaptiveAvgPool2d(1))
254
+ self.proj = nn.Sequential(nn.Linear(128, de), nn.LayerNorm(de))
255
+ self.constellation = Constellation(na, de, init_a)
256
+ self.patchwork = Patchwork(na, ncomp, dc)
257
+ self.mlp = nn.Sequential(
258
+ nn.Linear(ncomp*dc, dh), nn.GELU(), nn.LayerNorm(dh),
259
+ nn.Linear(dh, dh), nn.GELU(), nn.LayerNorm(dh),
260
+ nn.Linear(dh, nc))
261
+ def forward(self, x):
262
+ emb = F.normalize(self.proj(self.backbone(x).flatten(1)), dim=-1)
263
+ tri, near = self.constellation.triangulate(emb)
264
+ return self.mlp(self.patchwork(tri)), emb, tri, near
265
+ def encode(self, x):
266
+ return F.normalize(self.proj(self.backbone(x).flatten(1)), dim=-1)
267
+
268
+
269
+ # ══════════════════════════════════════════════════════════════════
270
+ # SHAPE RENDERERS WITH PERTURBATION PROFILES
271
+ # ══════════════════════════════════════════════════════════════════
272
+
273
+ def _d(img,x0,y0,x1,y1,t=1):
274
+ n=max(int(max(abs(x1-x0),abs(y1-y0))*2),1);sz=img.shape[0]
275
+ for s in np.linspace(0,1,n):
276
+ px,py=int(x0+s*(x1-x0)),int(y0+s*(y1-y0))
277
+ for dx in range(-t,t+1):
278
+ for dy in range(-t,t+1):
279
+ nx,ny=px+dx,py+dy
280
+ if 0<=nx<sz and 0<=ny<sz: img[ny,nx]=1.0
281
+
282
+ def rpoly(nv,sz=32,p=0.15,t=1,cx_off=0,cy_off=0):
283
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2+cx_off,sz/2+cy_off,sz*0.35
284
+ a=np.linspace(0,2*np.pi,nv,endpoint=False)+np.random.uniform(0,2*np.pi)
285
+ ri=r*(1+np.random.normal(0,p,nv))
286
+ pts=[(cx+ri[i]*np.cos(a[i]),cy+ri[i]*np.sin(a[i])) for i in range(nv)]
287
+ for i in range(nv): _d(img,*pts[i],*pts[(i+1)%nv],t)
288
+ return img
289
+
290
+ def rstar(np_,sz=32,p=0.12,t=1,cx_off=0,cy_off=0):
291
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off;ro,ri_=sz*0.38,sz*0.15
292
+ a=np.linspace(0,2*np.pi,np_*2,endpoint=False)+np.random.uniform(0,2*np.pi)
293
+ pts=[(cx+(ro if i%2==0 else ri_)*(1+np.random.normal(0,p))*np.cos(a[i]),
294
+ cy+(ro if i%2==0 else ri_)*(1+np.random.normal(0,p))*np.sin(a[i])) for i in range(len(a))]
295
+ for i in range(len(pts)): _d(img,*pts[i],*pts[(i+1)%len(pts)],t)
296
+ return img
297
+
298
+ def rcross(sz=32,p=0.15,t=2,cx_off=0,cy_off=0):
299
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,arm=sz/2+cx_off,sz/2+cy_off,sz*0.3
300
+ for ab in [0,np.pi/2,np.pi,3*np.pi/2]:
301
+ a=ab+np.random.normal(0,p*0.3);r=arm*(1+np.random.normal(0,p))
302
+ _d(img,cx,cy,cx+r*np.cos(a),cy+r*np.sin(a),t)
303
+ return img
304
+
305
+ def rspiral(sz=32,p=0.1,cx_off=0,cy_off=0):
306
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off
307
+ for t_ in np.linspace(0,5*np.pi,200):
308
+ r=sz*0.015*t_*(1+np.random.normal(0,p*0.3));x,y=int(cx+r*np.cos(t_)),int(cy+r*np.sin(t_))
309
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
310
+ return img
311
+
312
+ def rwave(sz=32,p=0.1,cx_off=0,cy_off=0):
313
+ img=np.zeros((sz,sz),dtype=np.float32);f=2+np.random.normal(0,0.3);amp=sz*0.15*(1+np.random.normal(0,p))
314
+ for x in range(sz):
315
+ y=int(sz/2+cy_off+amp*np.sin(2*np.pi*f*x/sz))
316
+ if 0<=y<sz: img[y,x]=1.0
317
+ return img
318
+
319
+ def rheart(sz=32,p=0.1,cx_off=0,cy_off=0):
320
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz*0.45+cy_off;s=sz*0.017*(1+np.random.normal(0,p))
321
+ for t_ in np.linspace(0,2*np.pi,300):
322
+ x=16*np.sin(t_)**3;y=-(13*np.cos(t_)-5*np.cos(2*t_)-2*np.cos(3*t_)-np.cos(4*t_))
323
+ ix,iy=int(cx+x*s),int(cy+y*s)
324
+ if 0<=ix<sz and 0<=iy<sz: img[iy,ix]=1.0
325
+ return img
326
+
327
+ def rcrescent(sz=32,p=0.1,cx_off=0,cy_off=0):
328
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2+cx_off,sz/2+cy_off,sz*0.35;r2=r*0.7;off=r*0.3
329
+ for a in np.linspace(0,2*np.pi,300):
330
+ x1,y1=cx+r*np.cos(a),cy+r*np.sin(a)
331
+ if math.sqrt((x1-cx-off)**2+(y1-cy)**2)>=r2*0.9:
332
+ if 0<=int(x1)<sz and 0<=int(y1)<sz: img[int(y1),int(x1)]=1.0
333
+ return img
334
+
335
+ def rellipse(sz=32,p=0.1,cx_off=0,cy_off=0):
336
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off
337
+ a,b=sz*0.38*(1+np.random.normal(0,p)),sz*0.22*(1+np.random.normal(0,p));rot=np.random.uniform(0,np.pi)
338
+ for t_ in np.linspace(0,2*np.pi,200):
339
+ x,y=a*np.cos(t_),b*np.sin(t_);ix,iy=int(cx+x*np.cos(rot)-y*np.sin(rot)),int(cy+x*np.sin(rot)+y*np.cos(rot))
340
+ if 0<=ix<sz and 0<=iy<sz: img[iy,ix]=1.0
341
+ return img
342
+
343
+ def rring(sz=32,p=0.1,cx_off=0,cy_off=0):
344
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off
345
+ r1,r2=sz*0.35*(1+np.random.normal(0,p)),sz*0.22*(1+np.random.normal(0,p))
346
+ for a in np.linspace(0,2*np.pi,300):
347
+ for r in [r1,r2]:
348
+ x,y=int(cx+r*np.cos(a)),int(cy+r*np.sin(a))
349
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
350
+ return img
351
+
352
+ def rarrow(sz=32,p=0.12,t=1,cx_off=0,cy_off=0):
353
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off
354
+ l=sz*0.35*(1+np.random.normal(0,p));h=l*0.35;a=np.random.uniform(0,2*np.pi)
355
+ x1,y1=cx-l*np.cos(a),cy-l*np.sin(a);x2,y2=cx+l*np.cos(a),cy+l*np.sin(a)
356
+ _d(img,x1,y1,x2,y2,t)
357
+ for da in [0.7,-0.7]: _d(img,x2,y2,x2-h*np.cos(a+da),y2-h*np.sin(a+da),t)
358
+ return img
359
+
360
+ def rchevron(sz=32,p=0.12,t=1,cx_off=0,cy_off=0):
361
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy=sz/2+cx_off,sz/2+cy_off
362
+ w,h=sz*0.3*(1+np.random.normal(0,p)),sz*0.25*(1+np.random.normal(0,p))
363
+ _d(img,cx-w,cy+h,cx,cy-h,t);_d(img,cx,cy-h,cx+w,cy+h,t)
364
+ return img
365
+
366
+ def rsemicirc(sz=32,p=0.1,t=1,cx_off=0,cy_off=0):
367
+ img=np.zeros((sz,sz),dtype=np.float32);cx,cy,r=sz/2+cx_off,sz*0.6+cy_off,sz*0.35
368
+ for a in np.linspace(np.pi,2*np.pi,150):
369
+ x,y=int(cx+r*np.cos(a)),int(cy+r*np.sin(a))
370
+ if 0<=x<sz and 0<=y<sz: img[y,x]=1.0
371
+ _d(img,cx-r,cy,cx+r,cy,t)
372
+ return img
373
+
374
+
375
+ # ── Dataset profiles ──
376
+
377
+ PROFILES = {
378
+ "A": {"p_scale": 1.0, "thickness": 1, "noise": 0.0, "shift": 0}, # standard
379
+ "B": {"p_scale": 1.5, "thickness": 2, "noise": 0.05, "shift": 0}, # noisy, thick
380
+ "C": {"p_scale": 0.7, "thickness": 1, "noise": 0.0, "shift": 3}, # precise, shifted
381
+ "D": {"p_scale": 1.2, "thickness": 1, "noise": 0.03, "shift": 2}, # moderate noise+shift
382
+ "E": {"p_scale": 1.0, "thickness": 1, "noise": 0.02, "shift": 1}, # gentle augmentation
383
+ }
384
+
385
+ def gen_one(c, sz=32, profile="A"):
386
+ pr = PROFILES[profile]
387
+ ps = pr["p_scale"]; t = pr["thickness"]; sh = pr["shift"]
388
+ cx_off = np.random.randint(-sh, sh+1) if sh > 0 else 0
389
+ cy_off = np.random.randint(-sh, sh+1) if sh > 0 else 0
390
+ base_p = [0.20,0.12,0.15,0.10,0.10,0.08,0.08,0.07,0.06,0.03,
391
+ 0.10,0.10,0.10,0.10,0.12,0.12,0.12,0.12,0.12,0.12,
392
+ 0.15,0.10,0.12,0.10,0.10,0.10,0.15,0.18,0.10,0.12]
393
+ p = base_p[c] * ps
394
+ kw = {"sz": sz, "cx_off": cx_off, "cy_off": cy_off}
395
+ R = [lambda: rpoly(3,p=p,t=t,**kw), lambda: rpoly(4,p=p,t=t,**kw),
396
+ lambda: rpoly(5,p=p,t=t,**kw), lambda: rpoly(6,p=p,t=t,**kw),
397
+ lambda: rpoly(7,p=p,t=t,**kw), lambda: rpoly(8,p=p,t=t,**kw),
398
+ lambda: rpoly(9,p=p,t=t,**kw), lambda: rpoly(10,p=p,t=t,**kw),
399
+ lambda: rpoly(12,p=p,t=t,**kw), lambda: rpoly(32,p=p*0.3,t=t,**kw),
400
+ lambda: rellipse(p=p,**kw), lambda: rspiral(p=p,**kw),
401
+ lambda: rwave(p=p,**kw), lambda: rcrescent(p=p,**kw),
402
+ lambda: rstar(3,p=p,t=t,**kw), lambda: rstar(4,p=p,t=t,**kw),
403
+ lambda: rstar(5,p=p,t=t,**kw), lambda: rstar(6,p=p,t=t,**kw),
404
+ lambda: rstar(7,p=p,t=t,**kw), lambda: rstar(8,p=p,t=t,**kw),
405
+ lambda: rcross(p=p,t=t,**kw), lambda: rpoly(4,p=p,t=t,**kw),
406
+ lambda: rarrow(p=p,t=t,**kw), lambda: rheart(p=p,**kw),
407
+ lambda: rring(p=p,**kw), lambda: rsemicirc(p=p,t=t,**kw),
408
+ lambda: rpoly(4,p=p*1.2,t=t,**kw), lambda: rpoly(4,p=p*1.5,t=t,**kw),
409
+ lambda: rpoly(4,p=p,t=t,**kw), lambda: rchevron(p=p,t=t,**kw)]
410
+ img = R[c]()
411
+ if pr["noise"] > 0:
412
+ img = img + np.random.normal(0, pr["noise"], img.shape).astype(np.float32)
413
+ img = np.clip(img, 0, 1)
414
+ return img
415
+
416
+ def gen_data(n_per=500, sz=32, profile="A", seed=None):
417
+ if seed is not None: np.random.seed(seed)
418
+ imgs, labels = [], []
419
+ for _ in range(n_per):
420
+ for c in range(30):
421
+ imgs.append(gen_one(c, sz, profile)); labels.append(c)
422
+ imgs = torch.tensor(np.array(imgs)).unsqueeze(1)
423
+ labels = torch.tensor(labels, dtype=torch.long)
424
+ perm = torch.randperm(len(labels))
425
+ return imgs[perm], labels[perm]
426
+
427
+ TYPES = {"polygon": list(range(9)), "curve": list(range(9,14)),
428
+ "star": list(range(14,20)), "structure": list(range(20,30))}
429
+
430
+ def eval_model(model, imgs, labels):
431
+ model.eval()
432
+ with torch.no_grad():
433
+ vl, ve, _, _ = model(imgs)
434
+ acc = (vl.argmax(-1) == labels).float().mean().item()
435
+ cv = cv_metric(ve)
436
+ ta = {}
437
+ for tn, tids in TYPES.items():
438
+ tm = torch.zeros(len(labels), dtype=bool, device=imgs.device)
439
+ for tid in tids: tm |= (labels == tid)
440
+ if tm.sum() > 0: ta[tn] = (vl.argmax(-1)[tm] == labels[tm]).float().mean().item()
441
+ return acc, cv, ta
442
+
443
+ def fmt_ta(ta):
444
+ return " ".join(f"{t}={a:.2f}" for t, a in ta.items())
445
+
446
+
447
+ # ══════════════════════════════════════���═══════════════════════════
448
+ # TRAINING FUNCTIONS
449
+ # ══════════════════════════════════════════════════════════════════
450
+
451
+ GEO_CFG = {"tang": 0.01, "sep": 1.0, "cv_w": 0.001, "spr": 1e-3, "ort": 1e-3, "ent": 1e-4}
452
+
453
+ def train_founder(model, tr_imgs, tr_labels, use_geo=True, epochs=30, tag=""):
454
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3)
455
+ BATCH = 256; nt = len(tr_labels)
456
+ for ep in range(epochs):
457
+ model.train(); perm = torch.randperm(nt, device=DEVICE); tc = 0
458
+ for i in range(0, nt, BATCH):
459
+ idx = perm[i:i+BATCH]
460
+ if len(idx) < 4: continue
461
+ lo, emb, tri, _ = model(tr_imgs[idx]); lab = tr_labels[idx]
462
+ anc = model.constellation.anchors
463
+ if use_geo:
464
+ eg = EmbeddingAutograd.apply(emb, emb, anc, GEO_CFG["tang"], GEO_CFG["sep"])
465
+ tg, _ = model.constellation.triangulate(eg)
466
+ lo = model.mlp(model.patchwork(tg))
467
+ l = F.cross_entropy(lo, lab)
468
+ lg = torch.tensor(0.0, device=DEVICE)
469
+ if use_geo:
470
+ lg += GEO_CFG["cv_w"] * cv_loss(emb)
471
+ lg += GEO_CFG["spr"] * anchor_spread_loss(anc)
472
+ lg += GEO_CFG["ort"] * anchor_ortho_loss(anc)
473
+ lg += GEO_CFG["ent"] * anchor_entropy_loss(emb, anc)
474
+ (l + lg).backward()
475
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
476
+ opt.step(); opt.zero_grad(set_to_none=True)
477
+ model.constellation.update_rigidity(tri.detach())
478
+ tc += (lo.argmax(-1) == lab).sum().item()
479
+ if (ep+1) % 10 == 0 or ep == 0:
480
+ acc, cv, ta = eval_model(model, val_imgs, val_labels)
481
+ print(f" {tag}E{ep+1:2d}: t={tc/nt:.3f} v={acc:.3f} cv={cv:.4f} [{fmt_ta(ta)}]")
482
+
483
+ def train_distilled(model, tr_imgs, tr_labels, consensus, epochs=30, tag=""):
484
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3)
485
+ BATCH = 256; nt = len(tr_labels)
486
+ for ep in range(epochs):
487
+ model.train(); perm = torch.randperm(nt, device=DEVICE); tc = 0
488
+ for i in range(0, nt, BATCH):
489
+ idx = perm[i:i+BATCH]
490
+ if len(idx) < 4: continue
491
+ lo, emb, tri, _ = model(tr_imgs[idx]); lab = tr_labels[idx]; tgt = consensus[idx]
492
+ anc = model.constellation.anchors
493
+ eg = EmbeddingAutograd.apply(emb, emb, anc, GEO_CFG["tang"], GEO_CFG["sep"])
494
+ tg, _ = model.constellation.triangulate(eg)
495
+ lo = model.mlp(model.patchwork(tg))
496
+ l_cls = F.cross_entropy(lo, lab)
497
+ l_nce = infonce(emb, tgt)
498
+ l_mse = F.mse_loss(emb, tgt)
499
+ l_cv = GEO_CFG["cv_w"] * cv_loss(emb)
500
+ l_ent = GEO_CFG["ent"] * anchor_entropy_loss(emb, anc)
501
+ (l_cls + 0.5*l_nce + 0.5*l_mse + l_cv + l_ent).backward()
502
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
503
+ opt.step(); opt.zero_grad(set_to_none=True)
504
+ model.constellation.update_rigidity(tri.detach())
505
+ tc += (lo.argmax(-1) == lab).sum().item()
506
+ if (ep+1) % 10 == 0 or ep == 0:
507
+ acc, cv, ta = eval_model(model, val_imgs, val_labels)
508
+ print(f" {tag}E{ep+1:2d}: t={tc/nt:.3f} v={acc:.3f} cv={cv:.4f} [{fmt_ta(ta)}]")
509
+
510
+
511
+ # ══════════════════════════════════════════════════════════════════
512
+ # VALIDATION DATA (always Dataset A β€” standard, consistent eval)
513
+ # ══════════════════════════════════════════════════════════════════
514
+
515
+ print(f"\n Generating validation data (Dataset A)...")
516
+ val_imgs, val_labels = gen_data(n_per=100, profile="A", seed=999)
517
+ val_imgs, val_labels = val_imgs.to(DEVICE), val_labels.to(DEVICE)
518
+ print(f" Val: {len(val_labels):,}")
519
+
520
+ all_results = {}
521
+ all_models = {} # keep references for final triplet parent selection
522
+
523
+
524
+ # ══════════════════════════════════════════════════════════════════
525
+ # GENERATION 0: 2 FOUNDERS on Dataset A
526
+ # ══════════════════════════════════════════════════════════════════
527
+
528
+ print(f"\n{'='*65}")
529
+ print("GEN 0: 2 FOUNDERS β€” Dataset A")
530
+ print(f"{'='*65}")
531
+
532
+ tr_A, lb_A = gen_data(n_per=500, profile="A", seed=42)
533
+ tr_A, lb_A = tr_A.to(DEVICE), lb_A.to(DEVICE)
534
+
535
+ for name, use_geo, sd in [("F0a", False, 100), ("F0b", True, 200)]:
536
+ print(f"\n ── {name} ──")
537
+ torch.manual_seed(sd)
538
+ m = PatchworkClassifier(init_a=None).to(DEVICE)
539
+ train_founder(m, tr_A, lb_A, use_geo=use_geo, tag=f"[{name}] ")
540
+ acc, cv, ta = eval_model(m, val_imgs, val_labels)
541
+ all_results[name] = {"acc": acc, "cv": cv, "ta": ta, "gen": 0}
542
+ all_models[name] = m
543
+ print(f" β†’ {name}: val={acc:.3f}")
544
+
545
+ # GPA consensus
546
+ print(f"\n GPA alignment (Gen 0)...")
547
+ embs_g0 = {n: m.encode(tr_A).detach() for n, m in all_models.items() if n.startswith("F0")}
548
+ cons_g0 = gpa_consensus(list(embs_g0.values()))
549
+ anc_g0 = consensus_anchors(cons_g0)
550
+ print(f" Consensus CV: {cv_metric(cons_g0[:2000]):.4f}")
551
+
552
+
553
+ # ══════════════════════════════════════════════════════════════════
554
+ # GENERATION 1: 2 STUDENTS β€” Dataset B and Dataset C
555
+ # ══════════════════════════════════════════════════════════════════
556
+
557
+ print(f"\n{'='*65}")
558
+ print("GEN 1: 2 STUDENTS β€” Datasets B and C")
559
+ print(f"{'='*65}")
560
+
561
+ tr_B, lb_B = gen_data(n_per=500, profile="B", seed=300)
562
+ tr_C, lb_C = gen_data(n_per=500, profile="C", seed=400)
563
+ tr_B, lb_B = tr_B.to(DEVICE), lb_B.to(DEVICE)
564
+ tr_C, lb_C = tr_C.to(DEVICE), lb_C.to(DEVICE)
565
+
566
+ # Need consensus targets indexed to each dataset's label ordering
567
+ # Since gen_data shuffles, we recompute consensus for each dataset
568
+ cons_g0_B = gpa_consensus([all_models["F0a"].encode(tr_B).detach(), all_models["F0b"].encode(tr_B).detach()])
569
+ cons_g0_C = gpa_consensus([all_models["F0a"].encode(tr_C).detach(), all_models["F0b"].encode(tr_C).detach()])
570
+
571
+ for name, tr, lb, cons, sd in [("G1_B", tr_B, lb_B, cons_g0_B, 301),
572
+ ("G1_C", tr_C, lb_C, cons_g0_C, 401)]:
573
+ print(f"\n ── {name} ──")
574
+ torch.manual_seed(sd)
575
+ m = PatchworkClassifier(init_a=consensus_anchors(cons)).to(DEVICE)
576
+ train_distilled(m, tr, lb, cons, tag=f"[{name}] ")
577
+ acc, cv, ta = eval_model(m, val_imgs, val_labels)
578
+ all_results[name] = {"acc": acc, "cv": cv, "ta": ta, "gen": 1}
579
+ all_models[name] = m
580
+ print(f" β†’ {name}: val={acc:.3f}")
581
+
582
+ del embs_g0; gc.collect(); torch.cuda.empty_cache()
583
+
584
+
585
+ # ══════════════════════════════════════════════════════════════════
586
+ # GENERATION 2: 3 OFFSPRING from G1 + 1 new founder, Dataset D
587
+ # ══════════════════════════════════════════════════════════════════
588
+
589
+ print(f"\n{'='*65}")
590
+ print("GEN 2: 3 OFFSPRING + new founder β€” Dataset D")
591
+ print(f"{'='*65}")
592
+
593
+ tr_D, lb_D = gen_data(n_per=500, profile="D", seed=500)
594
+ tr_D, lb_D = tr_D.to(DEVICE), lb_D.to(DEVICE)
595
+
596
+ # New founder on Dataset D
597
+ print(f"\n ── New founder (F1_D) ──")
598
+ torch.manual_seed(501)
599
+ f1d = PatchworkClassifier(init_a=None).to(DEVICE)
600
+ train_founder(f1d, tr_D, lb_D, use_geo=True, tag="[F1_D] ")
601
+ acc_f1d, _, _ = eval_model(f1d, val_imgs, val_labels)
602
+ all_results["F1_D"] = {"acc": acc_f1d, "cv": 0, "ta": {}, "gen": 1}
603
+ all_models["F1_D"] = f1d
604
+
605
+ # GPA from G1 + new founder (encode on Dataset D for consensus)
606
+ print(f"\n GPA alignment (G1 + F1_D on Dataset D)...")
607
+ g2_parents = ["G1_B", "G1_C", "F1_D"]
608
+ embs_g2 = [all_models[n].encode(tr_D).detach() for n in g2_parents]
609
+ cons_g2 = gpa_consensus(embs_g2)
610
+ anc_g2 = consensus_anchors(cons_g2)
611
+ print(f" Consensus CV: {cv_metric(cons_g2[:2000]):.4f}")
612
+
613
+ for i in range(3):
614
+ name = f"G2_{i}"
615
+ print(f"\n ── {name} ──")
616
+ torch.manual_seed(600 + i)
617
+ m = PatchworkClassifier(init_a=anc_g2).to(DEVICE)
618
+ train_distilled(m, tr_D, lb_D, cons_g2, tag=f"[{name}] ")
619
+ acc, cv, ta = eval_model(m, val_imgs, val_labels)
620
+ all_results[name] = {"acc": acc, "cv": cv, "ta": ta, "gen": 2}
621
+ all_models[name] = m
622
+ print(f" β†’ {name}: val={acc:.3f}")
623
+
624
+
625
+ # ══════════════════════════════════════════════════════════════════
626
+ # GENERATION 3: 5 MODELS β€” Dataset E (identical perturbation,
627
+ # different random samples)
628
+ # ══════════════════════════════════════════════════════════════════
629
+
630
+ print(f"\n{'='*65}")
631
+ print("GEN 3: 5 MODELS β€” Dataset E (identical profile, varied samples)")
632
+ print(f"{'='*65}")
633
+
634
+ # GPA from all G2 + new founder
635
+ g3_parents = [n for n in all_models if n.startswith("G2_")]
636
+ print(f" GPA alignment ({len(g3_parents)} G2 parents)...")
637
+
638
+ # Each Gen 3 model gets its own Dataset E sample
639
+ g3_models = []
640
+ for j in range(5):
641
+ name = f"G3_{j}"
642
+ tr_Ej, lb_Ej = gen_data(n_per=500, profile="E", seed=700 + j * 10)
643
+ tr_Ej, lb_Ej = tr_Ej.to(DEVICE), lb_Ej.to(DEVICE)
644
+
645
+ # Consensus from G2 parents on this dataset
646
+ embs_j = [all_models[n].encode(tr_Ej).detach() for n in g3_parents]
647
+ cons_j = gpa_consensus(embs_j)
648
+ anc_j = consensus_anchors(cons_j)
649
+
650
+ print(f"\n ── {name} ──")
651
+ torch.manual_seed(700 + j)
652
+ m = PatchworkClassifier(init_a=anc_j).to(DEVICE)
653
+ train_distilled(m, tr_Ej, lb_Ej, cons_j, tag=f"[{name}] ")
654
+ acc, cv, ta = eval_model(m, val_imgs, val_labels)
655
+ all_results[name] = {"acc": acc, "cv": cv, "ta": ta, "gen": 3}
656
+ all_models[name] = m
657
+ g3_models.append(name)
658
+ print(f" β†’ {name}: val={acc:.3f}")
659
+
660
+ del tr_Ej, lb_Ej; gc.collect(); torch.cuda.empty_cache()
661
+
662
+
663
+ # ══════════════════════════════════════════════════════════════════
664
+ # GENERATION 4 (FINAL): 3 TRIPLETS β€” each selects different 5
665
+ # parents from the FULL lineage
666
+ # ══════════════════════════════════════════════════════════════════
667
+
668
+ print(f"\n{'='*65}")
669
+ print("GEN 4 (FINAL): 3 TRIPLETS β€” cross-lineage parent selection")
670
+ print(f"{'='*65}")
671
+
672
+ # Sort all models by accuracy for parent selection
673
+ ranked = sorted(all_results.items(), key=lambda x: -x[1]["acc"])
674
+ ranked_names = [n for n, _ in ranked if n in all_models]
675
+
676
+ # Three different parent selection strategies
677
+ parent_sets = {
678
+ # Top 5 overall
679
+ "T4_best5": ranked_names[:5],
680
+ # Best from each generation
681
+ "T4_cross": [],
682
+ # Diverse: top + bottom + middle
683
+ "T4_diverse": [],
684
+ }
685
+
686
+ # Cross-generational: pick best from each gen
687
+ for gen in range(4):
688
+ gen_models = [(n, r) for n, r in ranked if r["gen"] == gen and n in all_models]
689
+ if gen_models:
690
+ parent_sets["T4_cross"].append(gen_models[0][0])
691
+ # Pad to 5 if needed
692
+ while len(parent_sets["T4_cross"]) < 5:
693
+ for n in ranked_names:
694
+ if n not in parent_sets["T4_cross"]:
695
+ parent_sets["T4_cross"].append(n); break
696
+
697
+ # Diverse: positions 0, 2, 4, 6, 8 from ranking
698
+ for idx in [0, 2, 4, 6, 8]:
699
+ if idx < len(ranked_names):
700
+ parent_sets["T4_diverse"].append(ranked_names[idx])
701
+
702
+ # Fresh eval data for final generation
703
+ tr_final, lb_final = gen_data(n_per=500, profile="A", seed=888)
704
+ tr_final, lb_final = tr_final.to(DEVICE), lb_final.to(DEVICE)
705
+
706
+ for name, parents in parent_sets.items():
707
+ print(f"\n ── {name} (parents: {parents}) ──")
708
+ embs_fin = [all_models[p].encode(tr_final).detach() for p in parents]
709
+ cons_fin = gpa_consensus(embs_fin)
710
+ anc_fin = consensus_anchors(cons_fin)
711
+ cons_cv = cv_metric(cons_fin[:2000])
712
+ print(f" Consensus CV: {cons_cv:.4f}")
713
+
714
+ torch.manual_seed(hash(name) % 2**32)
715
+ m = PatchworkClassifier(init_a=anc_fin).to(DEVICE)
716
+ train_distilled(m, tr_final, lb_final, cons_fin, tag=f"[{name}] ")
717
+ acc, cv, ta = eval_model(m, val_imgs, val_labels)
718
+ all_results[name] = {"acc": acc, "cv": cv, "ta": ta, "gen": 4}
719
+ all_models[name] = m
720
+ print(f" β†’ {name}: val={acc:.3f}")
721
+
722
+
723
+ # ══════════════════════════════════════════════════════════════════
724
+ # FINAL FUSION: ALL parents, ALL data
725
+ # ══════════════════════════════════════════════════════════════════
726
+
727
+ print(f"\n{'='*65}")
728
+ print("FINAL FUSION: ALL parents Γ— ALL data")
729
+ print(f"{'='*65}")
730
+
731
+ # Combine all datasets
732
+ print(f"\n Combining datasets A+B+C+D+E...")
733
+ all_datasets = []
734
+ all_labels_combined = []
735
+ for prof, seed in [("A", 42), ("B", 300), ("C", 400), ("D", 500), ("E", 700)]:
736
+ imgs, labs = gen_data(n_per=500, profile=prof, seed=seed)
737
+ all_datasets.append(imgs)
738
+ all_labels_combined.append(labs)
739
+
740
+ tr_all = torch.cat(all_datasets, dim=0).to(DEVICE)
741
+ lb_all = torch.cat(all_labels_combined, dim=0).to(DEVICE)
742
+
743
+ # Shuffle combined
744
+ perm_all = torch.randperm(len(lb_all))
745
+ tr_all = tr_all[perm_all]
746
+ lb_all = lb_all[perm_all]
747
+ print(f" Combined: {len(lb_all):,} samples (5 Γ— 15K)")
748
+
749
+ # ── Raw baseline on all data ──
750
+ print(f"\n ── FUSE_raw (all data, no distillation, no geometry) ──")
751
+ torch.manual_seed(42)
752
+ fuse_raw = PatchworkClassifier(init_a=None).to(DEVICE)
753
+ train_founder(fuse_raw, tr_all, lb_all, use_geo=False, epochs=30, tag="[FRAW] ")
754
+ acc_fr, cv_fr, ta_fr = eval_model(fuse_raw, val_imgs, val_labels)
755
+ all_results["FUSE_raw"] = {"acc": acc_fr, "cv": cv_fr, "ta": ta_fr, "gen": 5}
756
+ print(f" β†’ FUSE_raw: val={acc_fr:.3f}")
757
+
758
+ # ── All-parent consensus on combined data ──
759
+ print(f"\n Extracting ALL parents on combined data...")
760
+ all_parent_names = [n for n in all_models.keys()
761
+ if all_results[n]["acc"] > 0.1] # include everyone who trained
762
+ print(f" Parents ({len(all_parent_names)}): {all_parent_names}")
763
+
764
+ all_parent_embs = []
765
+ for n in all_parent_names:
766
+ all_models[n].eval()
767
+ with torch.no_grad():
768
+ # Encode in chunks to avoid OOM
769
+ chunks = []
770
+ for j in range(0, len(tr_all), 2048):
771
+ chunks.append(all_models[n].encode(tr_all[j:j+2048]).detach())
772
+ all_parent_embs.append(torch.cat(chunks, dim=0))
773
+
774
+ print(f" GPA alignment ({len(all_parent_embs)} models on {len(tr_all):,} samples)...")
775
+ cons_fuse = gpa_consensus(all_parent_embs)
776
+ cons_fuse_cv = cv_metric(cons_fuse[:2000])
777
+ print(f" Consensus CV: {cons_fuse_cv:.4f}")
778
+
779
+ anc_fuse = consensus_anchors(cons_fuse)
780
+ print(f" Anchors: {anc_fuse.shape}")
781
+
782
+ # ── Distilled student on all data from all parents ──
783
+ print(f"\n ── FUSE_distilled (all data, all parents, full pipeline) ──")
784
+ torch.manual_seed(42)
785
+ fuse_student = PatchworkClassifier(init_a=anc_fuse).to(DEVICE)
786
+ train_distilled(fuse_student, tr_all, lb_all, cons_fuse, epochs=30, tag="[FDST] ")
787
+ acc_fd, cv_fd, ta_fd = eval_model(fuse_student, val_imgs, val_labels)
788
+ all_results["FUSE_dist"] = {"acc": acc_fd, "cv": cv_fd, "ta": ta_fd, "gen": 5}
789
+ print(f" β†’ FUSE_distilled: val={acc_fd:.3f}")
790
+
791
+ # Clean up large tensors
792
+ del tr_all, lb_all, all_parent_embs, cons_fuse
793
+ gc.collect(); torch.cuda.empty_cache()
794
+
795
+
796
+ # ══════════════════════════════════════════════════════════════════
797
+ # EVOLUTION SUMMARY
798
+ # ══════════════════════════════════════════════════════════════════
799
+
800
+ print(f"\n\n{'='*65}")
801
+ print("EVOLUTION SUMMARY")
802
+ print(f"{'='*65}")
803
+
804
+ print(f"\n {'Model':<12} {'Gen':>3} {'v_acc':>6} {'cv':>7} "
805
+ f"{'poly':>5} {'curve':>5} {'star':>5} {'struct':>5}")
806
+ print(f" {'-'*58}")
807
+
808
+ for name in sorted(all_results.keys(), key=lambda x: (all_results[x]["gen"], x)):
809
+ r = all_results[name]
810
+ ta = r.get("ta", {})
811
+ print(f" {name:<12} {r['gen']:>3} {r['acc']:>6.3f} {r['cv']:>7.4f} "
812
+ f"{ta.get('polygon',0):>5.2f} {ta.get('curve',0):>5.2f} "
813
+ f"{ta.get('star',0):>5.2f} {ta.get('structure',0):>5.2f}")
814
+
815
+ print(f"\n Per-generation averages:")
816
+ for gen in range(6):
817
+ accs = [r["acc"] for r in all_results.values() if r["gen"] == gen and r["acc"] > 0]
818
+ if accs:
819
+ label = {0: "Gen 0 (founders)", 1: "Gen 1 (first offspring)",
820
+ 2: "Gen 2", 3: "Gen 3", 4: "Gen 4 (triplets)",
821
+ 5: "Gen 5 (FUSION)"}.get(gen, f"Gen {gen}")
822
+ print(f" {label}: mean={np.mean(accs):.3f} best={max(accs):.3f} n={len(accs)}")
823
+
824
+ print(f"\n{'='*65}")
825
+ print("DONE")
826
+ print(f"{'='*65}")