AbstractPhil commited on
Commit
cc4f091
Β·
verified Β·
1 Parent(s): fd243cb

Create constellation.py

Browse files
Files changed (1) hide show
  1. constellation.py +476 -0
constellation.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constellation β€” Unified Geometric Observer + Interpreter
3
+ ==========================================================
4
+ Configurable implementation covering all validated constellation forms.
5
+
6
+ PROVEN RESULTS:
7
+ Form 1 (Core): 91.5% CIFAR-10 @ 1.6M params, CV=0.2045
8
+ Form 5 (Relay): cos_to_orig=0.994 @ depth 16, 8.4Γ— faster than attn @ 131K
9
+ Hybrid: 88.0% CIFAR-10 @ 23.5M (conv encoder + constellation)
10
+ Scattering v1: 81.9% CIFAR-10 @ 17M (frozen scattering + constellation)
11
+
12
+ UNIVERSAL RULES (empirically validated):
13
+ - SquaredReLU in all constellation paths, never GELU
14
+ - Patchwork: Linear(in, in*2) β†’ SquaredReLU β†’ LN β†’ Linear(in*2, out)
15
+ - Gate init: -3.0 (sigmoid β‰ˆ 0.047) for relay/residual forms
16
+ - SLERP: acos in fp32, everything else in compute dtype
17
+ - Adam, NO weight decay β€” geometry IS regularization
18
+ - InfoNCE is alignment FORCE, Procrustes is REGULARIZER
19
+ - CV loss on the BOTTLENECK, weight 0.001 or below
20
+ - Anchor dropout (30%) prevents collapse in high-anchor configs
21
+
22
+ FORMS:
23
+ Constellation β€” observation + interpretation, configurable
24
+ ConstellationRelay β€” per-token geometric layer with gated residual
25
+
26
+ Usage:
27
+ from constellation import Constellation, ConstellationRelay
28
+
29
+ # Form 1 (Core): single vector per image
30
+ c = Constellation(n_anchors=16, dim=16, n_directions=8,
31
+ d_comp=64, n_phases=3)
32
+ output = c(directions) # (B, 8, 16) β†’ ConstellationOutput
33
+
34
+ # Form 5 (Relay): per-token processing
35
+ r = ConstellationRelay(dim=256, patch_dim=16, n_anchors=16)
36
+ out = r(tokens) # (B, S, 256) β†’ (B, S, 256)
37
+ """
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ import math
43
+ from dataclasses import dataclass
44
+ from typing import Optional
45
+
46
+
47
+ # ══════════════════════════════════════════════════════════════════
48
+ # ACTIVATION
49
+ # ══════════════════════════════════════════════════════════════════
50
+
51
+ class SquaredReLU(nn.Module):
52
+ """x β†’ ReLU(x)Β². Proven superior to GELU in all constellation paths."""
53
+ def forward(self, x):
54
+ return F.relu(x) ** 2
55
+
56
+
57
+ # ══════════════════════════════════════════════════════════════════
58
+ # ANCHOR INITIALIZATION
59
+ # ══════════════════════════════════════════════════════════════════
60
+
61
+ def init_anchors_xavier(n, d):
62
+ """Xavier normal β†’ normalize. Near-orthogonal in high-d. Used in Core."""
63
+ w = torch.empty(n, d)
64
+ nn.init.xavier_normal_(w)
65
+ return F.normalize(w, dim=-1)
66
+
67
+
68
+ def init_anchors_orthogonal(n, d):
69
+ """QR decomposition β†’ exact orthonormal basis. Used when n <= d."""
70
+ if n <= d:
71
+ M = torch.randn(d, n)
72
+ Q, _ = torch.linalg.qr(M)
73
+ return Q.T.contiguous()
74
+ else:
75
+ M = torch.randn(d, d)
76
+ Q, _ = torch.linalg.qr(M)
77
+ basis = Q.T
78
+ extra = F.normalize(torch.randn(n - d, d), dim=-1)
79
+ return torch.cat([basis, extra], dim=0)
80
+
81
+
82
+ def init_anchors_repulsion(n, d, iters=200, lr=0.05):
83
+ """QR + iterative repulsion for even coverage beyond d anchors."""
84
+ vecs = init_anchors_orthogonal(n, d)
85
+ vecs = F.normalize(vecs, dim=-1)
86
+ for _ in range(iters):
87
+ sim = vecs @ vecs.T
88
+ sim.fill_diagonal_(-2.0)
89
+ nn_idx = sim.argmax(dim=1)
90
+ vecs = F.normalize(vecs - lr * vecs[nn_idx], dim=-1)
91
+ return vecs
92
+
93
+
94
+ INIT_METHODS = {
95
+ 'xavier': init_anchors_xavier,
96
+ 'orthogonal': init_anchors_orthogonal,
97
+ 'repulsion': init_anchors_repulsion,
98
+ }
99
+
100
+
101
+ # ══════════════════════════════════════════════════════════════════
102
+ # OUTPUT
103
+ # ══════════════════════════════════════════════════════════════════
104
+
105
+ @dataclass
106
+ class ConstellationOutput:
107
+ """Full output from constellation forward pass."""
108
+ embedding: torch.Tensor # (B, pw_dim) β€” interpreted observation
109
+ cosines: torch.Tensor # (B, N, A) or (B, N, A*phases)
110
+ distances: torch.Tensor # (B, N, A) or (B, N, A*phases)
111
+ nearest: torch.Tensor # (B, N) β€” collapsed anchor assignment
112
+ directions: torch.Tensor # (B, N, D) β€” input directions on S^(D-1)
113
+ tri_flat: torch.Tensor # (B, tri_dim) β€” flattened triangulation
114
+
115
+
116
+ # ════════════════════════════���═════════════════════════════════════
117
+ # CONSTELLATION β€” observation + interpretation
118
+ # ══════════════════════════════════════════════════════════════════
119
+
120
+ class Constellation(nn.Module):
121
+ """Geometric observer with anchor-aligned interpretation.
122
+
123
+ Anchors on S^(D-1) observe input directions via triangulation.
124
+ Compartments interpret per-anchor observations.
125
+ SLERP phases provide multi-scale angular measurement.
126
+ All coupled through gradient flow.
127
+
128
+ Args:
129
+ n_anchors: reference directions on S^(D-1)
130
+ dim: anchor/direction dimensionality
131
+ n_directions: input directions per sample
132
+ d_comp: hidden dim per compartment
133
+ n_phases: SLERP interpolation phases (1=static, 3=proven default)
134
+ anchor_init: 'xavier', 'orthogonal', or 'repulsion'
135
+ anchor_dropout: fraction of anchors to drop during training (0.3 for soup)
136
+ compartment: 'aligned' (one per anchor) or 'flat' (single patchwork)
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ n_anchors: int,
142
+ dim: int,
143
+ n_directions: int,
144
+ d_comp: int = 64,
145
+ n_phases: int = 3,
146
+ anchor_init: str = 'xavier',
147
+ anchor_dropout: float = 0.0,
148
+ compartment: str = 'aligned',
149
+ ):
150
+ super().__init__()
151
+ self.n_anchors = n_anchors
152
+ self.dim = dim
153
+ self.n_directions = n_directions
154
+ self.d_comp = d_comp
155
+ self.n_phases = n_phases
156
+ self.anchor_dropout = anchor_dropout
157
+ self.compartment_type = compartment
158
+
159
+ # Anchors: home (frozen) + current (learned)
160
+ init_fn = INIT_METHODS[anchor_init]
161
+ home = init_fn(n_anchors, dim)
162
+ self.register_buffer('home', home)
163
+ self.anchors = nn.Parameter(home.clone())
164
+
165
+ # Triangulation dimensions
166
+ if compartment == 'aligned':
167
+ # tri: (B, N, A * phases) β†’ each compartment reads its anchor's column
168
+ self.tri_dim = n_directions * n_anchors * n_phases
169
+ self.embedding_dim = n_anchors * d_comp
170
+
171
+ # One compartment per anchor β€” reads tri[:, :, k] across all phases
172
+ # Input: n_directions * n_phases values per anchor
173
+ comp_in = n_directions * n_phases
174
+ self.compartments = nn.ModuleList([
175
+ nn.Sequential(
176
+ nn.Linear(comp_in, d_comp * 2),
177
+ SquaredReLU(),
178
+ nn.Linear(d_comp * 2, d_comp),
179
+ nn.LayerNorm(d_comp),
180
+ ) for _ in range(n_anchors)
181
+ ])
182
+ elif compartment == 'flat':
183
+ # tri: (B, tri_dim) β†’ single patchwork MLP
184
+ self.tri_dim = n_directions * n_anchors * n_phases
185
+ self.embedding_dim = dim
186
+
187
+ self.patchwork = nn.Sequential(
188
+ nn.Linear(self.tri_dim, self.tri_dim * 2),
189
+ SquaredReLU(),
190
+ nn.LayerNorm(self.tri_dim * 2),
191
+ nn.Linear(self.tri_dim * 2, dim),
192
+ )
193
+ else:
194
+ raise ValueError(f"Unknown compartment type: {compartment}")
195
+
196
+ self._init_weights()
197
+
198
+ def _init_weights(self):
199
+ for m in self.modules():
200
+ if isinstance(m, nn.Linear):
201
+ nn.init.trunc_normal_(m.weight, std=0.02)
202
+ if m.bias is not None:
203
+ nn.init.zeros_(m.bias)
204
+ elif isinstance(m, nn.LayerNorm):
205
+ nn.init.ones_(m.weight)
206
+ nn.init.zeros_(m.bias)
207
+
208
+ def drift(self):
209
+ """Geodesic distance between home and learned anchor positions."""
210
+ h = F.normalize(self.home.float(), dim=-1)
211
+ c = F.normalize(self.anchors.float(), dim=-1)
212
+ return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))
213
+
214
+ def at_phase(self, t):
215
+ """SLERP between home and learned positions at phase t ∈ [0, 1]."""
216
+ h = F.normalize(self.home.float(), dim=-1)
217
+ c = F.normalize(self.anchors.float(), dim=-1)
218
+ omega = self.drift().unsqueeze(-1) # (A, 1)
219
+ so = omega.sin().clamp(min=1e-6)
220
+ return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c
221
+
222
+ def _triangulate(self, directions, anchors):
223
+ """(B, N, D) Γ— (A, D) β†’ (B, N, A) cosines and distances."""
224
+ cos = torch.einsum('bnd,ad->bna', directions, anchors)
225
+ return cos, 1.0 - cos
226
+
227
+ def forward(self, directions: torch.Tensor) -> ConstellationOutput:
228
+ """Observe and interpret.
229
+
230
+ Args:
231
+ directions: (B, N, D) β€” L2-normalized to S^(D-1)
232
+
233
+ Returns:
234
+ ConstellationOutput
235
+ """
236
+ B, N, D = directions.shape
237
+
238
+ # Multi-phase triangulation
239
+ phases = torch.linspace(0, 1, self.n_phases, device=directions.device).tolist()
240
+ all_cos = []
241
+ all_dist = []
242
+ for t in phases:
243
+ anchors_t = F.normalize(self.at_phase(t), dim=-1).to(directions.dtype)
244
+
245
+ # Anchor dropout during training
246
+ if self.training and self.anchor_dropout > 0:
247
+ mask = torch.rand(anchors_t.shape[0], device=anchors_t.device) > self.anchor_dropout
248
+ if mask.sum() < 2:
249
+ mask[:2] = True
250
+ anchors_t = anchors_t[mask]
251
+
252
+ cos, dist = self._triangulate(directions, anchors_t)
253
+ all_cos.append(cos)
254
+ all_dist.append(dist)
255
+
256
+ # Stack phases: (B, N, A*phases) if no dropout, variable if dropout
257
+ cos_cat = torch.cat(all_cos, dim=-1)
258
+ dist_cat = torch.cat(all_dist, dim=-1)
259
+
260
+ # Nearest anchor (from phase 0, no dropout)
261
+ anchors_0 = F.normalize(self.at_phase(0.0), dim=-1).to(directions.dtype)
262
+ cos_0 = torch.einsum('bnd,ad->bna', directions, anchors_0)
263
+ nearest = cos_0.max(dim=-1).indices
264
+
265
+ # Interpret
266
+ if self.compartment_type == 'aligned' and not (self.training and self.anchor_dropout > 0):
267
+ # dist_cat: (B, N, A * n_phases)
268
+ # Reshape to (B, N, n_phases, A) then (B, A, N * n_phases)
269
+ A = self.n_anchors
270
+ dist_reshape = dist_cat.reshape(B, N, self.n_phases, A)
271
+ # For compartment k: gather distances to anchor k across all directions and phases
272
+ # dist_reshape[:, :, :, k] β†’ (B, N, n_phases) β†’ flatten β†’ (B, N*n_phases)
273
+ parts = []
274
+ for k in range(A):
275
+ comp_input = dist_reshape[:, :, :, k].reshape(B, N * self.n_phases)
276
+ parts.append(self.compartments[k](comp_input))
277
+ embedding = torch.cat(parts, dim=-1) # (B, A * d_comp)
278
+ elif self.compartment_type == 'flat' or (self.training and self.anchor_dropout > 0):
279
+ tri_flat = dist_cat.reshape(B, -1)
280
+ if self.compartment_type == 'flat':
281
+ embedding = self.patchwork(tri_flat)
282
+ else:
283
+ # Fallback for aligned + dropout: pad and use compartments
284
+ # This is a training-only path
285
+ embedding = torch.zeros(B, self.embedding_dim,
286
+ device=directions.device, dtype=directions.dtype)
287
+ # Use flat mean as fallback during dropout
288
+ for k in range(self.n_anchors):
289
+ comp_in_size = self.n_directions * self.n_phases
290
+ if tri_flat.shape[1] >= comp_in_size:
291
+ chunk = tri_flat[:, :comp_in_size]
292
+ else:
293
+ chunk = F.pad(tri_flat, (0, comp_in_size - tri_flat.shape[1]))
294
+ embedding[:, k * self.d_comp:(k + 1) * self.d_comp] = self.compartments[k](chunk)
295
+ else:
296
+ tri_flat = dist_cat.reshape(B, -1)
297
+ embedding = self.patchwork(tri_flat)
298
+
299
+ tri_flat = dist_cat.reshape(B, -1)
300
+
301
+ return ConstellationOutput(
302
+ embedding=embedding,
303
+ cosines=cos_cat,
304
+ distances=dist_cat,
305
+ nearest=nearest,
306
+ directions=directions,
307
+ tri_flat=tri_flat,
308
+ )
309
+
310
+
311
+ # ══════════════════════════════════════════════════════════════════
312
+ # CONSTELLATION RELAY β€” Form 5 (per-token geometric layer)
313
+ # ══════════════════════════════════════════════════════════════════
314
+
315
+ class ConstellationRelay(nn.Module):
316
+ """Per-token geometric processing layer with gated residual.
317
+
318
+ Replaces attention as a per-token processing layer.
319
+ O(S) complexity. No cross-token interaction.
320
+ Preserves 99.4% cosine similarity to input at depth 16.
321
+
322
+ Pipeline:
323
+ LayerNorm β†’ chunk D into patches β†’ L2 norm per patch
324
+ β†’ Constellation observation + interpretation
325
+ β†’ Project back to D β†’ gated residual
326
+
327
+ Args:
328
+ dim: token dimension (must be divisible by patch_dim)
329
+ patch_dim: dimension per patch subspace (default 16)
330
+ n_anchors: anchors per patch subspace
331
+ d_comp: hidden dim per compartment
332
+ n_phases: SLERP phases
333
+ gate_init: initial gate bias (default -3.0 β†’ sigmoid β‰ˆ 0.047)
334
+ anchor_init: initialization method
335
+ """
336
+
337
+ def __init__(
338
+ self,
339
+ dim: int,
340
+ patch_dim: int = 16,
341
+ n_anchors: int = 16,
342
+ d_comp: int = 64,
343
+ n_phases: int = 3,
344
+ gate_init: float = -3.0,
345
+ anchor_init: str = 'xavier',
346
+ ):
347
+ super().__init__()
348
+ assert dim % patch_dim == 0
349
+ self.dim = dim
350
+ self.patch_dim = patch_dim
351
+ self.n_patches = dim // patch_dim
352
+
353
+ self.norm = nn.LayerNorm(dim)
354
+
355
+ # Constellation operates on (B*S, n_patches, patch_dim)
356
+ self.constellation = Constellation(
357
+ n_anchors=n_anchors,
358
+ dim=patch_dim,
359
+ n_directions=self.n_patches,
360
+ d_comp=d_comp,
361
+ n_phases=n_phases,
362
+ anchor_init=anchor_init,
363
+ compartment='aligned',
364
+ )
365
+
366
+ # Project constellation embedding back to token dim
367
+ self.proj = nn.Linear(self.constellation.embedding_dim, dim)
368
+
369
+ # Gated residual β€” init at -3.0 so gate starts near 0
370
+ self.gate = nn.Parameter(torch.full((dim,), gate_init))
371
+
372
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
373
+ """
374
+ x: (B, S, D) or (B, D)
375
+ Returns: same shape as input
376
+ """
377
+ squeeze = False
378
+ if x.dim() == 2:
379
+ x = x.unsqueeze(1)
380
+ squeeze = True
381
+
382
+ B, S, D = x.shape
383
+ residual = x
384
+
385
+ # Normalize
386
+ h = self.norm(x)
387
+
388
+ # Chunk into patches and normalize to S^(patch_dim-1)
389
+ h_flat = h.reshape(B * S, self.n_patches, self.patch_dim)
390
+ h_flat = F.normalize(h_flat, dim=-1)
391
+
392
+ # Constellation: observe + interpret
393
+ output = self.constellation(h_flat)
394
+
395
+ # Project back to token dim
396
+ update = self.proj(output.embedding) # (B*S, D)
397
+ update = update.reshape(B, S, D)
398
+
399
+ # Gated residual
400
+ g = torch.sigmoid(self.gate)
401
+ out = residual + g * update
402
+
403
+ if squeeze:
404
+ out = out.squeeze(1)
405
+ return out
406
+
407
+
408
+ # ══════════════════════════════════════════════════════════════════
409
+ # GEOMETRIC OPS β€” measurement tools
410
+ # ══════════════════════════════════════════════════════════════════
411
+
412
+ class GeometricOps:
413
+ """Static geometric utilities for constellation monitoring and loss."""
414
+
415
+ @staticmethod
416
+ def cayley_menger_vol2(points):
417
+ """Squared simplex volume. points: (B, N, D) β†’ (B,)."""
418
+ B, N, D = points.shape
419
+ gram = torch.bmm(points, points.transpose(1, 2))
420
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
421
+ d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
422
+ d2 = F.relu(d2)
423
+ cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
424
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
425
+ k = N - 1
426
+ sign = (-1.0) ** (k + 1)
427
+ fact = math.factorial(k)
428
+ return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))
429
+
430
+ @staticmethod
431
+ def cv_metric(emb, n_samples=200, n_points=5):
432
+ """Non-differentiable CV for monitoring. Target band: 0.20–0.23."""
433
+ vols = []
434
+ for _ in range(n_samples):
435
+ idx = torch.randperm(emb.shape[0])[:n_points]
436
+ v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))
437
+ if v2[0] > 1e-20:
438
+ vols.append(v2[0].sqrt())
439
+ if len(vols) < 10:
440
+ return 0.0
441
+ vols_t = torch.stack(vols)
442
+ return (vols_t.std() / (vols_t.mean() + 1e-8)).item()
443
+
444
+ @staticmethod
445
+ def cv_loss(emb, target=0.22, n_samples=100, n_points=5):
446
+ """Differentiable CV loss. Weight: 0.001 or below."""
447
+ vols = []
448
+ for _ in range(n_samples):
449
+ idx = torch.randperm(min(emb.shape[0], 512))[:n_points]
450
+ v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))
451
+ if v2[0] > 1e-20:
452
+ vols.append(v2[0].sqrt())
453
+ if len(vols) < 5:
454
+ return torch.tensor(0.0, device=emb.device)
455
+ vols_t = torch.stack(vols)
456
+ cv = vols_t.std() / (vols_t.mean() + 1e-8)
457
+ return (cv - target).pow(2)
458
+
459
+ @staticmethod
460
+ def anchor_spread_loss(anchors, target_cos=0.0):
461
+ """Repulsion loss keeping anchors spread on the sphere."""
462
+ a = F.normalize(anchors, dim=-1)
463
+ sim = a @ a.T
464
+ mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
465
+ return F.relu(sim[mask] - target_cos).mean()
466
+
467
+ @staticmethod
468
+ def diagnostics(output: ConstellationOutput, n_anchors: int) -> dict:
469
+ """Compute diagnostic metrics."""
470
+ diag = {}
471
+ diag['n_active'] = output.nearest.flatten().unique().numel()
472
+ counts = torch.bincount(output.nearest.flatten(), minlength=n_anchors).float()
473
+ diag['anchor_util_std'] = counts.std().item()
474
+ diag['nearest_cos'] = output.cosines[:, :, :n_anchors].max(dim=-1).values.mean().item()
475
+ diag['mean_tri'] = output.distances.mean().item()
476
+ return diag