AbstractPhil commited on
Commit
1429fbb
Β·
verified Β·
1 Parent(s): 394b68b

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +1522 -0
model.py ADDED
@@ -0,0 +1,1522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geometric Transformer β€” CM-Validated Pipeline
3
+ ==================================================
4
+ Dual-stream transformer with CM-gated constellation observation,
5
+ quaternion composition, and per-layer Cayley alignment.
6
+
7
+ CM-validated pipeline changes:
8
+ - CM validity gate between association and curation (AnchorGate)
9
+ - 4-stream PositionGeometricContext: anchor + structural + history + quality
10
+ - CM-conditioned geometric residual accumulation (replaces blind learned gate)
11
+ - Built-in geometric regularization (CV target + anchor spread)
12
+ - Decomposed observer pipeline: association β†’ CM gate β†’ gated curation
13
+
14
+ Pipeline per layer:
15
+ 1. ManifoldProjection: h_i β†’ emb_i on S^(d-1) per position
16
+ 2. ConstellationAssociation: emb_i β†’ raw triangulation, cos, assignment
17
+ 3. CMValidatedGate: per-anchor CM validity β†’ gate_values (B*L, A)
18
+ 4. Gated curation: patchwork reads tri * gate_values (validated only)
19
+ 5. PositionGeometricContext: 4 streams β†’ FiLM context (B, L, context_dim)
20
+ 6. ContentAttention (Stream A): standard MHA
21
+ 7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure
22
+ 8. CayleyOrthogonal: align B β†’ A basis
23
+ 9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
24
+ 10. Decode + gated residual
25
+ 11. CM-conditioned geometric residual write
26
+
27
+ Geometric regularization (call model.geometric_losses() during training):
28
+ - CV loss: anchor CV β†’ pentachoron band (0.20-0.23)
29
+ - Spread loss: prevent anchor collapse (penalize positive cosine)
30
+ These maintain the constellation in the regime where CM validation works.
31
+
32
+ Design principles from Ryan Spearman (ρ=0.309, 76/84 wins):
33
+ - FiLM on Q,K ONLY β€” geometry routes attention, V stays pure
34
+ - FiLM on individual arms BEFORE composition, not after
35
+ - Quaternion algebra as structural regularizer (non-commutative coupling)
36
+ - CayleyOrthogonal guarantees pure rotation (det=1 always)
37
+ - Never global average pool β€” per-position geometric context
38
+
39
+ Author: AbstractPhil + Claude Opus 4.6
40
+ License: Apache 2.0
41
+ """
42
+
43
+ import math
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+
48
+ # ═══════════════════════════════════════════════════════════════════════════════
49
+ # GEOLIP IMPORTS β€” real components, not reimplementations
50
+ # ═══════════════════════════════════════════════════════════════════════════════
51
+
52
+ try:
53
+ from geolip_core.core.associate.constellation import (
54
+ ConstellationObserver, ConstellationAssociation, ConstellationCuration,
55
+ Constellation, init_anchors_repulsion,
56
+ )
57
+ from geolip_core.core.curate.gate import AnchorGate as _GeolipAnchorGate
58
+ from geolip_core.pipeline.observer import (
59
+ TorchComponent, BaseTower, Input, Curation, Distinction,
60
+ )
61
+ from geolip_core.core.distinguish.losses import (
62
+ observer_loss as _geolip_observer_loss,
63
+ ce_loss_paired as _geolip_ce_loss_paired,
64
+ cv_loss as _geolip_cv_loss,
65
+ spread_loss as _geolip_spread_loss,
66
+ )
67
+ _HAS_GEOLIP = True
68
+ except ImportError:
69
+ _HAS_GEOLIP = False
70
+
71
+ # ── Fallback stubs ──
72
+ class TorchComponent(nn.Module):
73
+ def __init__(self, name=None, **kwargs):
74
+ super().__init__()
75
+ self._component_name = name or self.__class__.__name__
76
+
77
+ class BaseTower(nn.Module):
78
+ def __init__(self, name=None, **kwargs):
79
+ super().__init__()
80
+ self._tower_name = name or self.__class__.__name__
81
+ self._components = nn.ModuleDict()
82
+ self._cache = {}
83
+
84
+ def attach(self, name, module):
85
+ if isinstance(module, nn.Module):
86
+ self._components[name] = module
87
+ return self
88
+
89
+ def has(self, name):
90
+ return name in self._components
91
+
92
+ def __getitem__(self, key):
93
+ return self._components[key]
94
+
95
+ def cache_set(self, key, value):
96
+ self._cache[key] = value
97
+
98
+ def cache_get(self, key, default=None):
99
+ return self._cache.get(key, default)
100
+
101
+ def cache_clear(self):
102
+ self._cache.clear()
103
+
104
+ Input = TorchComponent
105
+ Curation = TorchComponent
106
+ Distinction = TorchComponent
107
+
108
+ class Constellation(nn.Module):
109
+ """Learned anchors on S^(d-1). Triangulates input embeddings."""
110
+ def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
111
+ super().__init__()
112
+ self.n_anchors = n_anchors
113
+ self.dim = dim
114
+ anchors = torch.randn(n_anchors, dim)
115
+ anchors = F.normalize(anchors, dim=-1)
116
+ for _ in range(200):
117
+ sim = anchors @ anchors.T
118
+ sim.fill_diagonal_(-2.0)
119
+ anchors = F.normalize(anchors - 0.05 * anchors[sim.argmax(dim=1)], dim=-1)
120
+ self.anchors = nn.Parameter(anchors)
121
+
122
+ def forward(self, emb, training=False):
123
+ anchors = F.normalize(self.anchors, dim=-1)
124
+ cos = emb @ anchors.T
125
+ tri = 1.0 - cos
126
+ _, nearest = cos.max(dim=-1)
127
+ return tri, nearest
128
+
129
+ class ConstellationAssociation(TorchComponent):
130
+ """Association through constellation anchors."""
131
+ def __init__(self, dim=256, n_anchors=32, anchor_drop=0.0,
132
+ anchor_init='repulsion', assign_temp=0.1, **kwargs):
133
+ super().__init__(**kwargs)
134
+ self.assign_temp = assign_temp
135
+ self.constellation = Constellation(n_anchors, dim, anchor_drop, anchor_init)
136
+
137
+ @property
138
+ def frame_dim(self):
139
+ return self.constellation.n_anchors
140
+
141
+ def associate(self, emb, **context):
142
+ anchors_n = F.normalize(self.constellation.anchors, dim=-1)
143
+ cos = emb @ anchors_n.T
144
+ tri = 1.0 - cos
145
+ _, nearest = cos.max(dim=-1)
146
+ soft_assign = F.softmax(cos / self.assign_temp, dim=-1)
147
+ mag = context.get('mag', None)
148
+ distances_weighted = tri * mag if mag is not None else tri
149
+ return {
150
+ 'distances': tri, 'distances_weighted': distances_weighted,
151
+ 'cos_to_anchors': cos, 'assignment': soft_assign,
152
+ 'nearest': nearest,
153
+ }
154
+
155
+ def forward(self, emb, **context):
156
+ return self.associate(emb, **context)
157
+
158
+ class Patchwork(nn.Module):
159
+ """Round-robin patchwork compartments."""
160
+ def __init__(self, n_anchors, n_comp=8, d_comp=32, activation='gelu'):
161
+ super().__init__()
162
+ self.n_comp = n_comp
163
+ anchors_per = max(1, n_anchors // n_comp)
164
+ self.compartments = nn.ModuleList([
165
+ nn.Sequential(nn.Linear(anchors_per, d_comp), nn.GELU(), nn.Linear(d_comp, d_comp))
166
+ for _ in range(n_comp)
167
+ ])
168
+ self.output_dim = n_comp * d_comp
169
+ self.anchors_per = anchors_per
170
+
171
+ def forward(self, distances):
172
+ parts = []
173
+ for i, comp in enumerate(self.compartments):
174
+ start = i * self.anchors_per
175
+ end = start + self.anchors_per
176
+ chunk = distances[..., start:end]
177
+ if chunk.shape[-1] < self.anchors_per:
178
+ chunk = F.pad(chunk, (0, self.anchors_per - chunk.shape[-1]))
179
+ parts.append(comp(chunk))
180
+ return torch.cat(parts, dim=-1)
181
+
182
+ class ConstellationCuration(Curation):
183
+ """Curation through patchwork compartments + bridge."""
184
+ def __init__(self, n_anchors=32, dim=256, n_comp=8, d_comp=32,
185
+ activation='gelu', **kwargs):
186
+ super().__init__(**kwargs)
187
+ self.dim = dim
188
+ self.n_anchors = n_anchors
189
+ self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
190
+ pw_dim = self.patchwork.output_dim
191
+ self.bridge = nn.Linear(pw_dim, n_anchors)
192
+ self._feature_dim = n_anchors + pw_dim + dim
193
+
194
+ @property
195
+ def feature_dim(self):
196
+ return self._feature_dim
197
+
198
+ def curate_full(self, association_output, emb=None, **context):
199
+ distances = association_output['distances_weighted']
200
+ assignment = association_output['assignment']
201
+ pw = self.patchwork(distances)
202
+ bridge = self.bridge(pw)
203
+ parts = [assignment, pw]
204
+ if emb is not None:
205
+ parts.append(emb)
206
+ features = torch.cat(parts, dim=-1)
207
+ return {'patchwork': pw, 'bridge': bridge, 'features': features}
208
+
209
+ def forward(self, association_output, emb=None, **context):
210
+ return self.curate_full(association_output, emb=emb, **context)['features']
211
+
212
+ class ConstellationObserver(nn.Module):
213
+ """Composed association + curation."""
214
+ def __init__(self, dim=256, n_anchors=32, n_comp=8, d_comp=32,
215
+ anchor_drop=0.0, anchor_init='repulsion',
216
+ activation='gelu', assign_temp=0.1):
217
+ super().__init__()
218
+ self.association = ConstellationAssociation(
219
+ dim=dim, n_anchors=n_anchors, anchor_drop=anchor_drop,
220
+ anchor_init=anchor_init, assign_temp=assign_temp)
221
+ self.curation = ConstellationCuration(
222
+ n_anchors=n_anchors, dim=dim, n_comp=n_comp,
223
+ d_comp=d_comp, activation=activation)
224
+
225
+ @property
226
+ def constellation(self):
227
+ return self.association.constellation
228
+
229
+ @property
230
+ def patchwork(self):
231
+ return self.curation.patchwork
232
+
233
+ @property
234
+ def feature_dim(self):
235
+ return self.curation.feature_dim
236
+
237
+ def observe(self, emb, **context):
238
+ a_out = self.association(emb, **context)
239
+ c_out = self.curation.curate_full(a_out, emb=emb, **context)
240
+ return {
241
+ 'embedding': emb, 'features': c_out['features'],
242
+ 'triangulation': a_out['distances'],
243
+ 'cos_to_anchors': a_out['cos_to_anchors'],
244
+ 'nearest': a_out['nearest'],
245
+ 'assignment': a_out['assignment'],
246
+ 'patchwork': c_out['patchwork'], 'bridge': c_out['bridge'],
247
+ }
248
+
249
+ def forward(self, emb, **context):
250
+ return self.observe(emb, **context)
251
+
252
+
253
+ # ═══════════════════════════════════════════════════════════════════════════════
254
+ # CAYLEY-MENGER VALIDITY β€” geometric quality measurement
255
+ # ═══════════════════════════════════════════════════════════════════════════════
256
+
257
+ def pairwise_distances_squared(points):
258
+ """Batched pairwise squared distances. (B, N, D) β†’ (B, N, N)."""
259
+ gram = torch.bmm(points, points.transpose(1, 2))
260
+ diag = gram.diagonal(dim1=-2, dim2=-1)
261
+ return diag.unsqueeze(2) + diag.unsqueeze(1) - 2 * gram
262
+
263
+
264
+ def cayley_menger_det(points):
265
+ """Cayley-Menger signed volumeΒ² for simplices. (B, K, D) β†’ (B,).
266
+
267
+ K = number of vertices (k+1 for a k-simplex).
268
+ Sign-corrected: positive = valid non-degenerate simplex.
269
+ """
270
+ B, K, D = points.shape
271
+ d2 = pairwise_distances_squared(points)
272
+ M = torch.zeros(B, K + 1, K + 1, device=points.device, dtype=points.dtype)
273
+ M[:, 0, 1:] = 1.0
274
+ M[:, 1:, 0] = 1.0
275
+ M[:, 1:, 1:] = d2
276
+ raw = torch.linalg.det(M)
277
+ k = K - 1
278
+ sign = (-1.0) ** (k + 1)
279
+ return sign * raw
280
+
281
+
282
+ def anchor_neighborhood_cm(anchors, n_neighbors=3):
283
+ """Precompute per-anchor CM quality from local neighborhood geometry.
284
+
285
+ Position-independent. O(A) determinant computations on small matrices.
286
+ Each anchor forms a simplex with its k nearest neighbor anchors.
287
+ The CM determinant measures local geometric quality β€” high volume means
288
+ the anchor neighborhood is well-conditioned for triangulation.
289
+
290
+ Args:
291
+ anchors: (A, D) normalized anchor positions on S^(d-1)
292
+ n_neighbors: neighbors per simplex
293
+
294
+ Returns:
295
+ quality: (A,) signed log-magnitude CM quality per anchor
296
+ nn_idx: (A, n_neighbors) neighbor indices
297
+ """
298
+ A, D = anchors.shape
299
+ dists = torch.cdist(anchors.unsqueeze(0), anchors.unsqueeze(0)).squeeze(0)
300
+ # Mask self-distances without in-place mutation (compile-safe)
301
+ self_mask = torch.eye(A, device=anchors.device, dtype=anchors.dtype) * 1e12
302
+ dists = dists + self_mask
303
+ _, nn_idx = dists.topk(n_neighbors, largest=False) # (A, n_neighbors)
304
+
305
+ # Build simplices: [anchor_a, neighbor_1, ..., neighbor_k] per anchor
306
+ K = n_neighbors + 1
307
+ simplices = torch.zeros(A, K, D, device=anchors.device, dtype=anchors.dtype)
308
+ simplices[:, 0] = anchors
309
+ for j in range(n_neighbors):
310
+ simplices[:, j + 1] = anchors[nn_idx[:, j]]
311
+
312
+ dets = cayley_menger_det(simplices) # (A,)
313
+ sign = dets.sign()
314
+ log_mag = torch.log(dets.abs() + 1e-12)
315
+ return sign * log_mag, nn_idx
316
+
317
+
318
+ # ═══════════════════════════════════════════════════════════════════════════════
319
+ # CM VALIDATED GATE β€” efficient anchor gating for transformer scale
320
+ # ═══════════════════════════════════════════════════════════════════════════════
321
+
322
+ class CMValidatedGate(nn.Module):
323
+ """Anchor gate based on Cayley-Menger validity.
324
+
325
+ Efficient for transformer scale: anchor CM quality is precomputed O(AΒ²),
326
+ then combined with per-position proximity features through a learned gate.
327
+
328
+ The gate starts OPEN (bias=+2, sigmoidβ‰ˆ0.88) and learns to CLOSE on
329
+ geometrically invalid configurations. Architecture-before-loss: the gate
330
+ suppresses degenerate measurements structurally, not through a loss signal.
331
+
332
+ Gate features per (position, anchor):
333
+ - anchor_cm_quality: CM volume of anchor's local neighborhood (position-independent)
334
+ - cos_to_anchor: cosine similarity (position-dependent)
335
+ - distance_rank: normalized rank of this anchor by proximity (position-dependent)
336
+
337
+ Args:
338
+ n_anchors: number of constellation anchors
339
+ n_neighbors: neighbors for CM simplex computation
340
+ """
341
+ def __init__(self, n_anchors, n_neighbors=3):
342
+ super().__init__()
343
+ self.n_anchors = n_anchors
344
+ self.n_neighbors = n_neighbors
345
+
346
+ # Learned gate: [cm_quality, cos_sim, dist_rank] β†’ scalar gate
347
+ self.gate_proj = nn.Sequential(
348
+ nn.Linear(3, 16),
349
+ nn.GELU(),
350
+ nn.Linear(16, 1),
351
+ )
352
+ # Init OPEN β€” learn to close. sigmoid(2.0) β‰ˆ 0.88
353
+ nn.init.zeros_(self.gate_proj[2].weight)
354
+ nn.init.constant_(self.gate_proj[2].bias, 2.0)
355
+
356
+ def forward(self, embedding, anchors, tri):
357
+ """Compute per-(position, anchor) gate values.
358
+
359
+ Args:
360
+ embedding: (N, D) β€” positions on S^(d-1), where N = B*L
361
+ anchors: (A, D) β€” normalized anchor positions (DETACHED by caller)
362
+ tri: (N, A) β€” triangulation distances (1 - cos)
363
+
364
+ Returns:
365
+ gate_values: (N, A) in [0, 1] β€” per-anchor validity gate
366
+ gate_info: dict with diagnostics
367
+ """
368
+ N, A = tri.shape
369
+
370
+ # ── Anchor CM quality: position-independent, O(AΒ²) ──
371
+ with torch.no_grad():
372
+ anchor_cm, nn_idx = anchor_neighborhood_cm(anchors, self.n_neighbors)
373
+ # Normalize to ~ [-1, 1]
374
+ cm_std = anchor_cm.std().clamp(min=1e-8)
375
+ anchor_cm_norm = (anchor_cm - anchor_cm.mean()) / cm_std
376
+
377
+ # ── Per-position features ──
378
+ cos_sim = 1.0 - tri # (N, A)
379
+
380
+ # Distance rank: 0=nearest, 1=farthest
381
+ ranks = tri.argsort(dim=-1).argsort(dim=-1).float()
382
+ ranks = ranks / max(A - 1, 1)
383
+
384
+ # ── Gate features: (N, A, 3) ──
385
+ features = torch.stack([
386
+ anchor_cm_norm.unsqueeze(0).expand(N, -1),
387
+ cos_sim,
388
+ ranks,
389
+ ], dim=-1)
390
+
391
+ gate_values = torch.sigmoid(self.gate_proj(features).squeeze(-1))
392
+
393
+ # ── Diagnostics (no .item() β€” compile-safe) ──
394
+ with torch.no_grad():
395
+ active = (gate_values > 0.5).float().sum(-1).mean()
396
+ cm_positive_frac = (anchor_cm > 0).float().mean()
397
+ gate_mean = gate_values.mean()
398
+
399
+ gate_info = {
400
+ 'active': active,
401
+ 'gate_mean': gate_mean,
402
+ 'cm_positive_frac': cm_positive_frac,
403
+ 'anchor_cm': anchor_cm.detach(),
404
+ }
405
+
406
+ return gate_values, gate_info
407
+
408
+
409
+ # ═══════════════════════════════════════════════════════════════════════════════
410
+ # INFONCE MEMORY BANK β€” contrastive pressure on geometric residual
411
+ # ═══════════════════════════════════════════════════════════════════════════════
412
+
413
+ class GeoResidualBank(nn.Module):
414
+ """Cross-stream contrastive memory bank (CLIP-style).
415
+
416
+ Aligns content (Stream A CLS) and geometry (geo_residual CLS)
417
+ through contrastive learning. Same sample's content and geometry
418
+ should match; different samples' should not.
419
+
420
+ Bank stores projected geo_residual keys from recent batches.
421
+ Query is projected content CLS from current batch.
422
+ Positive pair: (content_i, geometry_i) from same sample.
423
+ Negatives: geometry from bank.
424
+
425
+ Gradient flows through BOTH streams:
426
+ - Content CLS β†’ transformer β†’ input (learns distinctive content)
427
+ - Geo residual CLS β†’ geo_proj β†’ patchwork β†’ CM gate β†’ constellation
428
+ (learns to observe what content finds relevant)
429
+
430
+ Args:
431
+ bank_size: number of entries in the queue
432
+ proj_dim: shared projection dimension for content and geometry
433
+ temperature: InfoNCE temperature
434
+ """
435
+ def __init__(self, proj_dim, bank_size=4096, temperature=0.1):
436
+ super().__init__()
437
+ self.proj_dim = proj_dim
438
+ self.bank_size = bank_size
439
+ self.temperature = temperature
440
+
441
+ # Queue of projected geo_residual keys
442
+ self.register_buffer('queue', torch.randn(bank_size, proj_dim))
443
+ self.queue = F.normalize(self.queue, dim=-1)
444
+ self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
445
+
446
+ @torch.no_grad()
447
+ def enqueue(self, keys):
448
+ """Add projected geo keys to queue. Called AFTER backward.
449
+ Args:
450
+ keys: (B, proj_dim) normalized projected geo_residual CLS
451
+ """
452
+ B = keys.shape[0]
453
+ ptr = int(self.queue_ptr.item())
454
+ if ptr + B <= self.bank_size:
455
+ self.queue[ptr:ptr + B] = keys
456
+ else:
457
+ overflow = (ptr + B) - self.bank_size
458
+ self.queue[ptr:] = keys[:B - overflow]
459
+ self.queue[:overflow] = keys[B - overflow:]
460
+ self.queue_ptr[0] = (ptr + B) % self.bank_size
461
+
462
+ def forward(self, content_proj, geo_proj):
463
+ """Cross-stream InfoNCE: content queries vs geometry keys.
464
+
465
+ Args:
466
+ content_proj: (B, proj_dim) β€” projected content CLS (LIVE, has grad)
467
+ geo_proj: (B, proj_dim) β€” projected geo_residual CLS (LIVE, has grad)
468
+
469
+ Returns:
470
+ loss: scalar InfoNCE loss
471
+ acc: top-1 retrieval accuracy (diagnostic)
472
+ """
473
+ q = F.normalize(content_proj, dim=-1) # (B, D)
474
+ k_pos = F.normalize(geo_proj, dim=-1) # (B, D) β€” positive keys
475
+ k_neg = self.queue.clone().detach() # (K, D) β€” negative keys from bank
476
+
477
+ # Positive logits: each content matches its own geometry
478
+ pos_logits = (q * k_pos).sum(dim=-1, keepdim=True) / self.temperature # (B, 1)
479
+
480
+ # Negative logits: each content vs all bank geometry
481
+ neg_logits = q @ k_neg.T / self.temperature # (B, K)
482
+
483
+ # InfoNCE: positive is column 0
484
+ logits = torch.cat([pos_logits, neg_logits], dim=1) # (B, 1+K)
485
+ labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
486
+
487
+ loss = F.cross_entropy(logits, labels)
488
+
489
+ with torch.no_grad():
490
+ acc = (logits.argmax(dim=1) == 0).float().mean()
491
+
492
+ return loss, acc
493
+
494
+
495
+ # ═══════════════════════════════════════════════════════════════════════════════
496
+ # PROVEN COMPONENTS β€” from Ryan Spearman (unchanged, tested)
497
+ # ═══════════════════════════════════════════════════════════════════════════════
498
+
499
+ class FiLMLayer(TorchComponent):
500
+ """Feature-wise Linear Modulation. Proven in Ryan Spearman.
501
+ Identity-initialized: Ξ³=1, Ξ²=0 at init.
502
+ """
503
+ def __init__(self, name, feature_dim, context_dim):
504
+ super().__init__(name)
505
+ self.to_gamma = nn.Linear(context_dim, feature_dim)
506
+ self.to_beta = nn.Linear(context_dim, feature_dim)
507
+ nn.init.zeros_(self.to_gamma.weight); nn.init.ones_(self.to_gamma.bias)
508
+ nn.init.zeros_(self.to_beta.weight); nn.init.zeros_(self.to_beta.bias)
509
+
510
+ def forward(self, x, ctx):
511
+ return self.to_gamma(ctx) * x + self.to_beta(ctx)
512
+
513
+
514
+ class CayleyOrthogonal(TorchComponent):
515
+ """Guaranteed SO(d) rotation via Cayley map. det(Q) = 1 always."""
516
+ def __init__(self, name, dim):
517
+ super().__init__(name)
518
+ self.dim = dim
519
+ self.A_upper = nn.Parameter(torch.zeros(dim * (dim - 1) // 2) * 0.01)
520
+ idx = torch.triu_indices(dim, dim, offset=1)
521
+ self.register_buffer('_triu_row', idx[0], persistent=False)
522
+ self.register_buffer('_triu_col', idx[1], persistent=False)
523
+ self.register_buffer('_eye', torch.eye(dim), persistent=False)
524
+
525
+ def get_rotation(self):
526
+ d = self.dim
527
+ A = torch.zeros(d, d, device=self.A_upper.device, dtype=self.A_upper.dtype)
528
+ A[self._triu_row, self._triu_col] = self.A_upper
529
+ A = A - A.T
530
+ return torch.linalg.solve(self._eye + A, self._eye - A)
531
+
532
+ def forward(self, x):
533
+ return x @ self.get_rotation().T
534
+
535
+
536
+ def quaternion_multiply_batched(q1, q2):
537
+ """Hamilton product on (B, 4, D) tensors. Fully vectorized."""
538
+ w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
539
+ w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
540
+ return torch.stack([
541
+ w1*w2 - x1*x2 - y1*y2 - z1*z2,
542
+ w1*x2 + x1*w2 + y1*z2 - z1*y2,
543
+ w1*y2 - x1*z2 + y1*w2 + z1*x2,
544
+ w1*z2 + x1*y2 - y1*x2 + z1*w2,
545
+ ], dim=1)
546
+
547
+
548
+ class QuaternionCompose(TorchComponent):
549
+ """Four-arm Hamilton product composition. Proven in GeoQuat head.
550
+ Fully vectorized: single batched Hamilton product, no Python loops.
551
+ """
552
+ def __init__(self, name, input_dim, quat_dim=64):
553
+ super().__init__(name)
554
+ self.quat_dim = quat_dim
555
+ self.proj_w = nn.Linear(input_dim, quat_dim)
556
+ self.proj_i = nn.Linear(input_dim, quat_dim)
557
+ self.proj_j = nn.Linear(input_dim, quat_dim)
558
+ self.proj_k = nn.Linear(input_dim, quat_dim)
559
+ self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1)
560
+
561
+ @property
562
+ def output_dim(self):
563
+ return self.quat_dim * 4
564
+
565
+ def forward(self, arm_w, arm_i, arm_j, arm_k):
566
+ shape = arm_w.shape[:-1]
567
+ D = arm_w.shape[-1]
568
+ flat = arm_w.dim() > 2
569
+ if flat:
570
+ arm_w = arm_w.reshape(-1, D); arm_i = arm_i.reshape(-1, D)
571
+ arm_j = arm_j.reshape(-1, D); arm_k = arm_k.reshape(-1, D)
572
+ q = torch.stack([self.proj_w(arm_w), self.proj_i(arm_i),
573
+ self.proj_j(arm_j), self.proj_k(arm_k)], dim=1)
574
+ q = q / (q.norm(dim=1, keepdim=True) + 1e-8)
575
+ r = self.rotation.expand(q.shape[0], -1, -1)
576
+ r = r / (r.norm(dim=1, keepdim=True) + 1e-8)
577
+ composed = quaternion_multiply_batched(r, q)
578
+ composed = composed.reshape(q.shape[0], -1)
579
+ if flat:
580
+ composed = composed.reshape(*shape, -1)
581
+ return composed
582
+
583
+
584
+ # ═════════��═════════════════════════════════════════════════════════════════════
585
+ # TRANSFORMER-SPECIFIC COMPONENTS
586
+ # ═══════════════════════════════════════════════════════════════════════════════
587
+
588
+ class ManifoldProjection(TorchComponent):
589
+ """Input stage: project transformer hidden states to S^(d-1).
590
+ Per-position, per-layer. L2-normalized to unit hypersphere.
591
+ """
592
+ def __init__(self, name, d_model, manifold_dim):
593
+ super().__init__(name)
594
+ self.proj = nn.Linear(d_model, manifold_dim)
595
+ self.norm = nn.LayerNorm(manifold_dim)
596
+
597
+ def forward(self, hidden_states):
598
+ h = self.norm(self.proj(hidden_states))
599
+ return F.normalize(h, dim=-1)
600
+
601
+
602
+ class PositionGeometricContext(TorchComponent):
603
+ """Curation stage: 4-stream fusion β†’ FiLM context.
604
+
605
+ Four streams:
606
+ anchor: cos_to_anchors + assignment + triangulation β€” WHERE on the manifold
607
+ structural: patchwork + embedding β€” WHAT the local geometry looks like
608
+ history: geo_residual from previous layers β€” WHAT prior layers observed
609
+ quality: CM gate values per anchor β€” HOW TRUSTWORTHY is this observation
610
+
611
+ The quality stream gives FiLM direct knowledge of which anchors formed
612
+ valid simplices. This is not a scalar β€” the full (N, A) gate profile
613
+ tells the context WHICH directions on the manifold are reliable.
614
+ """
615
+ def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim):
616
+ super().__init__(name)
617
+ self.context_dim = context_dim
618
+ self.pw_dim = pw_dim
619
+
620
+ # WHERE on the manifold
621
+ self.anchor_mlp = nn.Sequential(
622
+ nn.Linear(n_anchors * 3, context_dim), nn.GELU(), nn.LayerNorm(context_dim))
623
+ # WHAT the local geometry looks like
624
+ self.struct_mlp = nn.Sequential(
625
+ nn.Linear(pw_dim + manifold_dim, context_dim), nn.GELU(), nn.LayerNorm(context_dim))
626
+ # WHAT prior layers observed
627
+ self.history_mlp = nn.Sequential(
628
+ nn.Linear(pw_dim, context_dim), nn.GELU(), nn.LayerNorm(context_dim))
629
+ # HOW TRUSTWORTHY β€” full per-anchor gate profile
630
+ self.quality_mlp = nn.Sequential(
631
+ nn.Linear(n_anchors, context_dim), nn.GELU(), nn.LayerNorm(context_dim))
632
+
633
+ # Fuse 4 streams
634
+ self.fuse = nn.Sequential(
635
+ nn.Linear(context_dim * 4, context_dim), nn.GELU(), nn.LayerNorm(context_dim))
636
+
637
+ def forward(self, obs_dict, gate_values=None, geo_residual=None):
638
+ """
639
+ Args:
640
+ obs_dict: from decomposed association + gated curation
641
+ gate_values: (N, A) CM gate values per anchor, or None
642
+ geo_residual: (N, pw_dim) accumulated context, or None for first layer
643
+ Returns:
644
+ (N, context_dim) geometric context for FiLM
645
+ """
646
+ anchor_feats = torch.cat([
647
+ obs_dict['cos_to_anchors'],
648
+ obs_dict['assignment'],
649
+ obs_dict['triangulation'],
650
+ ], dim=-1)
651
+ struct_feats = torch.cat([
652
+ obs_dict['patchwork'],
653
+ obs_dict['embedding'],
654
+ ], dim=-1)
655
+
656
+ a = self.anchor_mlp(anchor_feats)
657
+ s = self.struct_mlp(struct_feats)
658
+ h = self.history_mlp(geo_residual) if geo_residual is not None else torch.zeros_like(a)
659
+ q = self.quality_mlp(gate_values) if gate_values is not None else torch.zeros_like(a)
660
+
661
+ return self.fuse(torch.cat([a, s, h, q], dim=-1))
662
+
663
+
664
+ class GeometricAttention(TorchComponent):
665
+ """Attention with FiLM from curated constellation. Stream B.
666
+ FiLM modulates Q,K BEFORE attention. V stays unmodulated.
667
+ """
668
+ def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1):
669
+ super().__init__(name)
670
+ self.d_model = d_model
671
+ self.n_heads = n_heads
672
+ self.head_dim = d_model // n_heads
673
+ self.scale = self.head_dim ** -0.5
674
+
675
+ self.w_q = nn.Linear(d_model, d_model)
676
+ self.w_k = nn.Linear(d_model, d_model)
677
+ self.w_v = nn.Linear(d_model, d_model)
678
+ self.w_o = nn.Linear(d_model, d_model)
679
+ self.dropout = nn.Dropout(dropout)
680
+
681
+ self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim)
682
+ self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim)
683
+ self.norm = nn.LayerNorm(d_model)
684
+
685
+ self.ffn1 = nn.Linear(d_model, d_model * 4)
686
+ self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim)
687
+ self.ffn2 = nn.Linear(d_model * 4, d_model)
688
+ self.ffn_drop = nn.Dropout(dropout)
689
+ self.ffn_norm = nn.LayerNorm(d_model)
690
+
691
+ def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None):
692
+ B, L, D = x.shape
693
+ H, HD = self.n_heads, self.head_dim
694
+
695
+ Q = self.film_q(self.w_q(x), geo_ctx)
696
+ K = self.film_k(self.w_k(x), geo_ctx)
697
+ V = self.w_v(x)
698
+
699
+ Q = Q.view(B, L, H, HD).transpose(1, 2)
700
+ K = K.view(B, L, H, HD).transpose(1, 2)
701
+ V = V.view(B, L, H, HD).transpose(1, 2)
702
+
703
+ scores = (Q @ K.transpose(-2, -1)) * self.scale
704
+ if attn_mask is not None:
705
+ scores = scores + attn_mask
706
+ if key_padding_mask is not None:
707
+ scores = scores.masked_fill(
708
+ key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
709
+ attn_out = (self.dropout(F.softmax(scores, dim=-1)) @ V)
710
+ attn_out = attn_out.transpose(1, 2).reshape(B, L, D)
711
+ x = self.norm(x + self.w_o(attn_out))
712
+
713
+ h = F.gelu(self.ffn1(x))
714
+ h = self.film_ffn(h, geo_ctx)
715
+ x = self.ffn_norm(x + self.ffn_drop(self.ffn2(h)))
716
+ return x
717
+
718
+
719
+ class ContentAttention(TorchComponent):
720
+ """Standard self-attention. Stream A. No geometric conditioning."""
721
+ def __init__(self, name, d_model, n_heads=8, dropout=0.1):
722
+ super().__init__(name)
723
+ self.attn = nn.MultiheadAttention(
724
+ d_model, n_heads, dropout=dropout, batch_first=True)
725
+ self.norm = nn.LayerNorm(d_model)
726
+ self.ffn = nn.Sequential(
727
+ nn.Linear(d_model, d_model * 4), nn.GELU(),
728
+ nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
729
+ self.ffn_norm = nn.LayerNorm(d_model)
730
+
731
+ def forward(self, x, attn_mask=None, key_padding_mask=None):
732
+ a, _ = self.attn(x, x, x, attn_mask=attn_mask,
733
+ key_padding_mask=key_padding_mask)
734
+ x = self.norm(x + a)
735
+ x = self.ffn_norm(x + self.ffn(x))
736
+ return x
737
+
738
+
739
+ # ═══════════════════════════════════════════════════════════════════════════════
740
+ # LAYER β€” CM-validated dual-stream with constellation routing
741
+ # ═══════════════════════════════════════════════════════════════════════════════
742
+
743
+ class GeometricTransformerLayer(BaseTower):
744
+ """One layer of the geometric transformer (CM validated).
745
+
746
+ Pipeline per layer:
747
+ 1. ManifoldProjection: h β†’ emb on S^(d-1)
748
+ 2. Association: emb β†’ raw triangulation, cos, assignment
749
+ 3. CMValidatedGate: per-anchor CM validity β†’ gate_values
750
+ 4. Gated curation: patchwork reads tri * gate_values
751
+ 5. PositionGeometricContext: 4 streams β†’ FiLM context
752
+ 6. ContentAttention (Stream A): standard MHA
753
+ 7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx)
754
+ 8. CayleyOrthogonal: align B β†’ A
755
+ 9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
756
+ 10. Decode + gated residual
757
+ 11. CM-conditioned geometric residual accumulation
758
+
759
+ The observer is DECOMPOSED: association and curation are called
760
+ separately with the CM gate inserted between them. The gate
761
+ suppresses degenerate anchor measurements before the patchwork
762
+ reads them. The patchwork only interprets validated geometry.
763
+
764
+ The geometric residual is accumulated using CM quality as the
765
+ write weight β€” no learned gate. Positions with high-quality
766
+ simplex observations contribute more. Positions in degenerate
767
+ regions contribute less.
768
+ """
769
+ def __init__(self, name, d_model, n_heads=8, n_anchors=32,
770
+ manifold_dim=256, n_comp=8, d_comp=32,
771
+ context_dim=128, quat_dim=64, dropout=0.1,
772
+ cm_neighbors=3):
773
+ super().__init__(name)
774
+ self.d_model = d_model
775
+ self.n_anchors = n_anchors
776
+
777
+ # 1. Project to manifold
778
+ self.attach('projection', ManifoldProjection(
779
+ f'{name}_proj', d_model, manifold_dim))
780
+
781
+ # 2. Constellation observer (association + curation β€” called decomposed)
782
+ self.attach('observer', ConstellationObserver(
783
+ dim=manifold_dim, n_anchors=n_anchors,
784
+ n_comp=n_comp, d_comp=d_comp))
785
+
786
+ # 3. CM validated gate β€” between association and curation
787
+ self.attach('cm_gate', CMValidatedGate(
788
+ n_anchors=n_anchors, n_neighbors=cm_neighbors))
789
+
790
+ # 4. Fuse observation into FiLM context (4 streams)
791
+ pw_dim = self['observer'].curation.patchwork.output_dim
792
+ self.attach('context', PositionGeometricContext(
793
+ f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim))
794
+
795
+ # 5. Stream A: content
796
+ self.attach('content', ContentAttention(
797
+ f'{name}_content', d_model, n_heads, dropout))
798
+
799
+ # 6. Stream B: geometric
800
+ self.attach('geometric', GeometricAttention(
801
+ f'{name}_geo', d_model, n_heads, context_dim, dropout))
802
+
803
+ # 7. Cayley rotation: align B β†’ A
804
+ self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model))
805
+
806
+ # 8. Quaternion composition
807
+ self.attach('compose', QuaternionCompose(
808
+ f'{name}_quat', d_model, quat_dim))
809
+
810
+ # 9. Decode + output gate
811
+ self.attach('decode', nn.Sequential(
812
+ nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model)))
813
+ self.attach('gate', nn.Sequential(
814
+ nn.Linear(d_model * 2, d_model), nn.Sigmoid()))
815
+
816
+ # 10. Geometric residual projection (no learned gate β€” CM quality decides)
817
+ self._pw_dim = pw_dim
818
+ self.attach('geo_proj', nn.Sequential(
819
+ nn.Linear(pw_dim, pw_dim), nn.LayerNorm(pw_dim)))
820
+
821
+ def forward(self, x, geo_residual=None, attn_mask=None, key_padding_mask=None):
822
+ """
823
+ Args:
824
+ x: (B, L, D) input hidden states
825
+ geo_residual: (B, L, pw_dim) accumulated geometric context,
826
+ or None for first layer
827
+
828
+ Returns:
829
+ x_out: (B, L, D) transformed hidden states
830
+ geo_residual_out: (B, L, pw_dim) updated geometric residual
831
+ geo_state: dict with full geometric state + CM diagnostics
832
+ """
833
+ B, L, D = x.shape
834
+
835
+ # ════ 1. Project to manifold ════
836
+ emb = self['projection'](x) # (B, L, manifold_dim)
837
+ emb_flat = emb.reshape(B * L, -1)
838
+
839
+ # ════ 2. Association β€” raw triangulation ════
840
+ a_out = self['observer'].association(emb_flat)
841
+
842
+ # ════ 3. CM Gate β€” validate anchor measurements ════
843
+ anchors_n = F.normalize(
844
+ self['observer'].association.constellation.anchors, dim=-1)
845
+ gate_values, gate_info = self['cm_gate'](
846
+ emb_flat, anchors_n.detach(), a_out['distances'])
847
+
848
+ # ════ 4. Gated curation β€” patchwork reads validated triangulation ════
849
+ a_out_gated = dict(a_out)
850
+ a_out_gated['distances_weighted'] = a_out['distances'] * gate_values
851
+ c_out = self['observer'].curation.curate_full(a_out_gated, emb=emb_flat)
852
+
853
+ # Build observation dict for context
854
+ obs = {
855
+ 'embedding': emb_flat,
856
+ 'triangulation': a_out['distances'],
857
+ 'cos_to_anchors': a_out['cos_to_anchors'],
858
+ 'assignment': a_out['assignment'],
859
+ 'nearest': a_out['nearest'],
860
+ 'patchwork': c_out['patchwork'],
861
+ 'bridge': c_out['bridge'],
862
+ }
863
+
864
+ # ════ 5. Build FiLM context β€” 4 streams ════
865
+ geo_res_flat = geo_residual.reshape(B * L, -1) if geo_residual is not None else None
866
+ geo_ctx_flat = self['context'](
867
+ obs, gate_values=gate_values, geo_residual=geo_res_flat)
868
+ geo_ctx = geo_ctx_flat.reshape(B, L, -1)
869
+
870
+ # ════ 6. Stream A: content attention ════
871
+ a_out_stream = self['content'](
872
+ x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
873
+
874
+ # ════ 7. Stream B: geometric attention ════
875
+ b_out = self['geometric'](
876
+ x, geo_ctx, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
877
+
878
+ # ════ 8. Cayley rotation: align B β†’ A ════
879
+ b_aligned = self['rotation'](b_out)
880
+
881
+ # ════ 9. Quaternion composition ════
882
+ composed = self['compose'](
883
+ arm_w=a_out_stream, arm_i=b_aligned,
884
+ arm_j=a_out_stream - b_aligned, arm_k=a_out_stream * b_aligned)
885
+
886
+ # ════ 10. Decode + gated residual ════
887
+ decoded = self['decode'](composed)
888
+ g = self['gate'](torch.cat([x, decoded], dim=-1))
889
+ x_out = g * decoded + (1 - g) * x
890
+
891
+ # ════ 11. CM-conditioned geometric residual accumulation ════
892
+ # CM quality per position: mean gate value across anchors.
893
+ # High quality = position's simplex with anchors is non-degenerate.
894
+ # Low quality = position is in a boundary region or near dead anchors.
895
+ pw_validated = c_out['patchwork'].reshape(B, L, -1)
896
+ cm_quality = gate_values.mean(dim=-1).reshape(B, L, 1) # (B, L, 1)
897
+ geo_update = self['geo_proj'](pw_validated)
898
+
899
+ if geo_residual is None:
900
+ geo_residual_out = cm_quality * geo_update
901
+ else:
902
+ geo_residual_out = geo_residual + cm_quality * geo_update
903
+
904
+ # ════ Build geo_state dict ════
905
+ def _unflatten(t):
906
+ if t is None:
907
+ return None
908
+ if t.dim() == 1:
909
+ return t.reshape(B, L)
910
+ return t.reshape(B, L, *t.shape[1:])
911
+
912
+ geo_state = {
913
+ 'embedding': emb,
914
+ 'geo_ctx': geo_ctx,
915
+ 'triangulation': _unflatten(a_out['distances']),
916
+ 'cos_to_anchors': _unflatten(a_out['cos_to_anchors']),
917
+ 'assignment': _unflatten(a_out['assignment']),
918
+ 'nearest': _unflatten(a_out['nearest']),
919
+ 'patchwork': _unflatten(c_out['patchwork']),
920
+ 'bridge': _unflatten(c_out['bridge']),
921
+ 'gate_values': _unflatten(gate_values),
922
+ 'gate_info': gate_info,
923
+ 'cm_quality': cm_quality,
924
+ 'content': a_out_stream,
925
+ 'geometric': b_out,
926
+ 'composed': composed,
927
+ 'geo_residual': geo_residual_out,
928
+ }
929
+
930
+ return x_out, geo_residual_out, geo_state
931
+
932
+
933
+ # ═══════════════════════════════════════════════════════════════════════════════
934
+ # FULL MODEL β€” stack of layers + geometric regularization
935
+ # ═══════════════════════════════════════════════════════════════════════════════
936
+
937
+ class GeometricTransformer(BaseTower):
938
+ """Geometric Transformer β€” CM-validated dual-stream.
939
+
940
+ Stack of GeometricTransformerLayers with:
941
+ - CM-gated observation at every layer
942
+ - Cross-layer Cayley rotation on hidden states (not geo_residual)
943
+ - Built-in geometric regularization via geometric_losses()
944
+ """
945
+ def __init__(self, name, d_model=512, n_heads=8, n_layers=4,
946
+ n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32,
947
+ context_dim=128, quat_dim=64, dropout=0.1,
948
+ cross_layer_rotation=True, cm_neighbors=3,
949
+ nce_bank_size=4096, nce_temperature=0.1,
950
+ vocab_size=None, max_seq_len=2048):
951
+ super().__init__(name)
952
+ self.d_model = d_model
953
+ self.n_layers = n_layers
954
+ self.n_anchors = n_anchors
955
+ self._pw_dim = n_comp * d_comp
956
+
957
+ if vocab_size is not None:
958
+ self.attach('embed', nn.Embedding(vocab_size, d_model))
959
+ self.attach('pos_embed', nn.Embedding(max_seq_len, d_model))
960
+ self.attach('head', nn.Linear(d_model, vocab_size, bias=False))
961
+
962
+ for i in range(n_layers):
963
+ self.attach(f'layer_{i}', GeometricTransformerLayer(
964
+ f'{name}_L{i}', d_model, n_heads, n_anchors,
965
+ manifold_dim, n_comp, d_comp, context_dim, quat_dim,
966
+ dropout, cm_neighbors))
967
+
968
+ if cross_layer_rotation and n_layers > 1:
969
+ for i in range(n_layers - 1):
970
+ self.attach(f'cross_rot_{i}', CayleyOrthogonal(
971
+ f'{name}_xrot_{i}', d_model))
972
+
973
+ self.attach('final_norm', nn.LayerNorm(d_model))
974
+
975
+ # Cross-stream contrastive (CLIP-style): content CLS vs geometry CLS
976
+ # Two projections map content (d_model) and geometry (pw_dim) to shared space
977
+ if nce_bank_size > 0:
978
+ nce_proj_dim = 128
979
+ self.attach('nce_content_proj', nn.Sequential(
980
+ nn.Linear(d_model, nce_proj_dim),
981
+ nn.GELU(),
982
+ nn.Linear(nce_proj_dim, nce_proj_dim),
983
+ ))
984
+ self.attach('nce_geo_proj', nn.Sequential(
985
+ nn.Linear(self._pw_dim, nce_proj_dim),
986
+ nn.GELU(),
987
+ nn.Linear(nce_proj_dim, nce_proj_dim),
988
+ ))
989
+ self.attach('nce_bank', GeoResidualBank(
990
+ nce_proj_dim, bank_size=nce_bank_size,
991
+ temperature=nce_temperature))
992
+
993
+ self._config = dict(
994
+ d_model=d_model, n_heads=n_heads, n_layers=n_layers,
995
+ n_anchors=n_anchors, manifold_dim=manifold_dim,
996
+ n_comp=n_comp, d_comp=d_comp, context_dim=context_dim,
997
+ quat_dim=quat_dim, dropout=dropout,
998
+ cross_layer_rotation=cross_layer_rotation,
999
+ cm_neighbors=cm_neighbors, vocab_size=vocab_size,
1000
+ nce_bank_size=nce_bank_size, nce_temperature=nce_temperature,
1001
+ )
1002
+
1003
+ @property
1004
+ def config(self):
1005
+ return self._config.copy()
1006
+
1007
+ def geometric_losses(self, cv_target=0.215, cv_weight=0.1, spread_weight=0.01):
1008
+ """Compute geometric regularization from current anchor geometry.
1009
+
1010
+ These losses maintain the constellation in the regime where
1011
+ CM validation, patchwork interpretation, and the full observation
1012
+ pipeline produce meaningful results.
1013
+
1014
+ CV loss: push anchor coefficient of variation toward pentachoron
1015
+ band (0.20-0.23). This is where CM computation has maximal
1016
+ discriminative power β€” anchors are neither too uniform (CVβ‰ˆ0,
1017
+ CM uninformative) nor too clustered (CV>0.3, degenerate simplices).
1018
+
1019
+ Spread loss: penalize positive cosine similarity between anchors.
1020
+ Prevents collapse where multiple anchors occupy the same region,
1021
+ creating redundant measurements and wasting patchwork capacity.
1022
+
1023
+ Returns:
1024
+ dict with 'cv', 'spread', 'geo_total' loss tensors
1025
+ """
1026
+ total_cv = torch.tensor(0.0)
1027
+ total_spread = torch.tensor(0.0)
1028
+ n = 0
1029
+
1030
+ for i in range(self.n_layers):
1031
+ layer = self[f'layer_{i}']
1032
+ anchors = layer['observer'].association.constellation.anchors
1033
+ anchors_n = F.normalize(anchors, dim=-1)
1034
+ A = anchors_n.shape[0]
1035
+
1036
+ # Ensure we're on the right device
1037
+ if n == 0:
1038
+ total_cv = total_cv.to(anchors.device)
1039
+ total_spread = total_spread.to(anchors.device)
1040
+
1041
+ # ── CV loss: pairwise angular distance coefficient of variation ──
1042
+ cos = anchors_n @ anchors_n.T
1043
+ idx = torch.triu_indices(A, A, offset=1, device=cos.device)
1044
+ pairwise_dist = 1.0 - cos[idx[0], idx[1]]
1045
+ cv = pairwise_dist.std() / (pairwise_dist.mean() + 1e-8)
1046
+ total_cv = total_cv + (cv - cv_target).pow(2)
1047
+
1048
+ # ── Spread loss: penalize positive cosine between anchors ──
1049
+ mask = ~torch.eye(A, dtype=torch.bool, device=cos.device)
1050
+ total_spread = total_spread + F.relu(cos[mask]).mean()
1051
+
1052
+ n += 1
1053
+
1054
+ losses = {}
1055
+ if n > 0:
1056
+ losses['cv'] = cv_weight * total_cv / n
1057
+ losses['spread'] = spread_weight * total_spread / n
1058
+ losses['geo_total'] = losses['cv'] + losses['spread']
1059
+ return losses
1060
+
1061
+ def infonce_loss(self, cls_index=0):
1062
+ """Cross-stream contrastive: content queries against decoupled geometry.
1063
+
1064
+ The constellation provides a STABLE geometric reference frame.
1065
+ The content stream needs discriminative correction.
1066
+ The InfoNCE targets weaker content representations by measuring
1067
+ them against the constellation's observation.
1068
+
1069
+ Gradient path (info-side only):
1070
+ - nce_content_proj ← hidden_cls ← transformer ← input (LIVE)
1071
+ - nce_geo_proj ← learns to read detached residual (LIVE proj, FROZEN input)
1072
+ - geo_residual ← constellation/patchwork/geo_proj (DETACHED β€” decoupled)
1073
+
1074
+ The constellation's anchors never see NCE gradient.
1075
+ Both projection heads learn from InfoNCE to find shared space.
1076
+ Content stream receives corrective gradient at weak positions.
1077
+
1078
+ Returns:
1079
+ dict with 'nce': loss tensor, 'nce_acc': retrieval accuracy
1080
+ """
1081
+ if not self.has('nce_bank'):
1082
+ return {}
1083
+
1084
+ hidden = getattr(self, '_last_hidden', None)
1085
+ geo_residual = getattr(self, '_last_geo_residual', None)
1086
+ if hidden is None or geo_residual is None:
1087
+ return {}
1088
+
1089
+ # Content CLS β†’ shared space (LIVE β€” info-side gets gradient)
1090
+ content_cls = self['nce_content_proj'](hidden[:, cls_index])
1091
+
1092
+ # Geo residual CLS β†’ shared space (DETACHED input β€” constellation decoupled)
1093
+ # nce_geo_proj itself IS trainable β€” learns to read the frozen residual
1094
+ geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach())
1095
+
1096
+ loss, acc = self['nce_bank'](content_cls, geo_cls)
1097
+ return {'nce': loss, 'nce_acc': acc}
1098
+
1099
+ @torch.no_grad()
1100
+ def update_nce_bank(self, cls_index=0):
1101
+ """Enqueue projected geo keys into bank. Call AFTER backward."""
1102
+ if not self.has('nce_bank') or not self.has('nce_geo_proj'):
1103
+ return
1104
+
1105
+ geo_residual = getattr(self, '_last_geo_residual', None)
1106
+ if geo_residual is None:
1107
+ return
1108
+
1109
+ geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach())
1110
+ self['nce_bank'].enqueue(F.normalize(geo_cls, dim=-1))
1111
+
1112
+ def anchor_diagnostics(self):
1113
+ """Per-layer anchor health diagnostics. Call for monitoring."""
1114
+ diag = {}
1115
+ for i in range(self.n_layers):
1116
+ layer = self[f'layer_{i}']
1117
+ anchors = layer['observer'].association.constellation.anchors
1118
+ anchors_n = F.normalize(anchors.detach(), dim=-1)
1119
+ A = anchors_n.shape[0]
1120
+
1121
+ cos = anchors_n @ anchors_n.T
1122
+ idx = torch.triu_indices(A, A, offset=1, device=cos.device)
1123
+ pairwise = 1.0 - cos[idx[0], idx[1]]
1124
+ cv = (pairwise.std() / (pairwise.mean() + 1e-8)).item()
1125
+
1126
+ # CM quality per anchor
1127
+ with torch.no_grad():
1128
+ anchor_cm, _ = anchor_neighborhood_cm(
1129
+ anchors_n, layer['cm_gate'].n_neighbors)
1130
+
1131
+ diag[f'layer_{i}'] = {
1132
+ 'anchor_cv': cv,
1133
+ 'mean_pairwise_dist': pairwise.mean().item(),
1134
+ 'min_pairwise_dist': pairwise.min().item(),
1135
+ 'cm_positive_frac': (anchor_cm > 0).float().mean().item(),
1136
+ 'cm_mean': anchor_cm.mean().item(),
1137
+ 'cm_std': anchor_cm.std().item(),
1138
+ }
1139
+ return diag
1140
+
1141
+ def param_report(self):
1142
+ total = 0
1143
+ name = getattr(self, '_tower_name', self.__class__.__name__)
1144
+ print(f"\n {name} β€” parameter report (CM-validated)")
1145
+ print(f" {'Component':<35s} {'Params':>12s}")
1146
+ print(f" {'─'*35} {'─'*12}")
1147
+ for cname, module in self.named_children():
1148
+ n = sum(p.numel() for p in module.parameters())
1149
+ total += n
1150
+ print(f" {cname:<35s} {n:>12,}")
1151
+ print(f" {'─'*35} {'─'*12}")
1152
+ print(f" {'TOTAL':<35s} {total:>12,}")
1153
+ return total
1154
+
1155
+ def forward(self, x, attn_mask=None, key_padding_mask=None,
1156
+ return_geo_state=False):
1157
+ """
1158
+ Returns:
1159
+ out: (B, L, D) transformed hidden states (or logits if head attached)
1160
+ geo_states: list of per-layer geo_state dicts (if return_geo_state)
1161
+
1162
+ Side effect:
1163
+ self._last_geo_residual is set to the final geo_residual (B, L, pw_dim)
1164
+ for use by infonce_loss() and update_nce_bank() without changing the return API.
1165
+ """
1166
+ if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
1167
+ pos = torch.arange(x.shape[1], device=x.device)
1168
+ x = self['embed'](x) + self['pos_embed'](pos)
1169
+
1170
+ geo_states = []
1171
+ has_xrot = self.has('cross_rot_0')
1172
+ geo_residual = None
1173
+
1174
+ for i in range(self.n_layers):
1175
+ x, geo_residual, geo_state = self[f'layer_{i}'](
1176
+ x, geo_residual=geo_residual,
1177
+ attn_mask=attn_mask, key_padding_mask=key_padding_mask)
1178
+ if return_geo_state:
1179
+ geo_states.append(geo_state)
1180
+ if has_xrot and i < self.n_layers - 1:
1181
+ x = self[f'cross_rot_{i}'](x)
1182
+ # geo_residual NOT rotated β€” lives in patchwork space, basis-independent
1183
+
1184
+ # Cache for cross-stream contrastive: content CLS vs geometry CLS
1185
+ self._last_geo_residual = geo_residual
1186
+ self._last_hidden = x # pre-norm hidden states β€” content representation
1187
+
1188
+ x = self['final_norm'](x)
1189
+ if self.has('head'):
1190
+ x = self['head'](x)
1191
+
1192
+ return (x, geo_states) if return_geo_state else x
1193
+
1194
+ # ── Paired forward + observer loss ──────────────────────────────
1195
+
1196
+ def _run_view(self, x, attn_mask=None, key_padding_mask=None):
1197
+ """Run one view through the full pipeline.
1198
+
1199
+ Returns:
1200
+ features: (B, L, D) transformed hidden states (post-norm)
1201
+ geo_states: list of per-layer geo_state dicts
1202
+ """
1203
+ geo_states = []
1204
+ has_xrot = self.has('cross_rot_0')
1205
+ geo_residual = None
1206
+
1207
+ if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
1208
+ pos = torch.arange(x.shape[1], device=x.device)
1209
+ x = self['embed'](x) + self['pos_embed'](pos)
1210
+
1211
+ for i in range(self.n_layers):
1212
+ x, geo_residual, geo_state = self[f'layer_{i}'](
1213
+ x, geo_residual=geo_residual,
1214
+ attn_mask=attn_mask, key_padding_mask=key_padding_mask)
1215
+ geo_states.append(geo_state)
1216
+ if has_xrot and i < self.n_layers - 1:
1217
+ x = self[f'cross_rot_{i}'](x)
1218
+
1219
+ x = self['final_norm'](x)
1220
+ return x, geo_states
1221
+
1222
+ def forward_paired(self, x1, x2, cls_index=0,
1223
+ attn_mask=None, key_padding_mask=None):
1224
+ """Dual-view forward for observer loss training.
1225
+
1226
+ Runs both views through the full CM-gated pipeline, extracts
1227
+ CLS-position geometric state from the final layer, and packages
1228
+ into the observe_paired output format expected by observer_loss().
1229
+
1230
+ Args:
1231
+ x1, x2: (B, L, D) two views of input hidden states
1232
+ cls_index: position index for image-level outputs (default 0)
1233
+
1234
+ Returns:
1235
+ output dict matching observer_loss spec:
1236
+ embedding, embedding_aug, patchwork1, patchwork1_aug,
1237
+ bridge1, bridge2, assign1, assign2, cos1, tri1, tri2
1238
+ Plus: features1, features2, geo_states1, geo_states2
1239
+ """
1240
+ feat1, gs1 = self._run_view(x1, attn_mask, key_padding_mask)
1241
+ feat2, gs2 = self._run_view(x2, attn_mask, key_padding_mask)
1242
+
1243
+ # Extract CLS position from final layer geo_state
1244
+ g1 = gs1[-1]
1245
+ g2 = gs2[-1]
1246
+ c = cls_index
1247
+
1248
+ return {
1249
+ # observe_paired format β€” what observer_loss reads
1250
+ 'embedding': g1['embedding'][:, c],
1251
+ 'embedding_aug': g2['embedding'][:, c],
1252
+ 'patchwork1': g1['patchwork'][:, c],
1253
+ 'patchwork1_aug': g2['patchwork'][:, c],
1254
+ 'bridge1': g1['bridge'][:, c],
1255
+ 'bridge2': g2['bridge'][:, c],
1256
+ 'assign1': g1['assignment'][:, c],
1257
+ 'assign2': g2['assignment'][:, c],
1258
+ 'cos1': g1['cos_to_anchors'][:, c],
1259
+ 'tri1': g1['triangulation'][:, c],
1260
+ 'tri2': g2['triangulation'][:, c],
1261
+ # Full features for task head
1262
+ 'features1': feat1,
1263
+ 'features2': feat2,
1264
+ # Diagnostics
1265
+ 'gate_values1': g1['gate_values'][:, c],
1266
+ 'gate_values2': g2['gate_values'][:, c],
1267
+ 'cm_quality1': g1['cm_quality'],
1268
+ 'cm_quality2': g2['cm_quality'],
1269
+ 'geo_states1': gs1,
1270
+ 'geo_states2': gs2,
1271
+ }
1272
+
1273
+ def compute_loss(self, output, targets, cls_index=0,
1274
+ w_ce=1.0, head=None, **loss_kwargs):
1275
+ """Three-domain observer loss through the CM-gated pipeline.
1276
+
1277
+ Follows ConstellationEncoder.compute_loss pattern:
1278
+ observer_loss (geometric + internal) + CE (external)
1279
+
1280
+ The observer_loss reads patchwork, bridge, assign, tri, cos β€”
1281
+ all of which flowed through the CM gate during forward_paired.
1282
+
1283
+ Args:
1284
+ output: dict from forward_paired()
1285
+ targets: (B,) class labels
1286
+ cls_index: which position has the CLS token
1287
+ w_ce: weight on cross-entropy loss
1288
+ head: nn.Module mapping (B, D) β†’ (B, num_classes), or None
1289
+ **loss_kwargs: forwarded to observer_loss (w_nce_pw, w_bridge, etc.)
1290
+
1291
+ Returns:
1292
+ (total_loss, loss_dict)
1293
+ """
1294
+ # Get anchors from final layer's constellation
1295
+ final_layer = self[f'layer_{self.n_layers - 1}']
1296
+ anchors = final_layer['observer'].association.constellation.anchors
1297
+
1298
+ # Observer self-organization loss (geometric + internal)
1299
+ obs_loss, ld = _geolip_observer_loss(
1300
+ output, anchors=anchors, targets=targets,
1301
+ **loss_kwargs)
1302
+
1303
+ # Task loss if head provided
1304
+ if head is not None:
1305
+ feat1 = output['features1'][:, cls_index]
1306
+ feat2 = output['features2'][:, cls_index]
1307
+ logits1 = head(feat1)
1308
+ logits2 = head(feat2)
1309
+ l_ce, acc = _geolip_ce_loss_paired(logits1, logits2, targets)
1310
+ ld['ce'], ld['acc'] = l_ce, acc
1311
+ ld['logits'] = logits1
1312
+ loss = w_ce * l_ce + obs_loss
1313
+ ld['loss_task'] = l_ce.item()
1314
+ else:
1315
+ loss = obs_loss
1316
+
1317
+ # Anchor maintenance across ALL layers (not just final)
1318
+ total_spread = torch.tensor(0.0, device=anchors.device)
1319
+ for i in range(self.n_layers):
1320
+ layer = self[f'layer_{i}']
1321
+ layer_anchors = layer['observer'].association.constellation.anchors
1322
+ total_spread = total_spread + _geolip_spread_loss(layer_anchors)
1323
+ ld['spread_all_layers'] = total_spread / self.n_layers
1324
+
1325
+ ld['loss_observer'] = obs_loss.item()
1326
+ ld['total'] = loss
1327
+ return loss, ld
1328
+
1329
+
1330
+ # ═══════════════════════════════════════════════════════════════════════════════
1331
+ # FACTORIES
1332
+ # ═══════════════════════════════════════════════════════════════════════════════
1333
+
1334
+ def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw):
1335
+ """Pre-configured for ESM-2 650M (d=1280)."""
1336
+ return GeometricTransformer(name, d_model=1280, n_heads=16,
1337
+ n_layers=n_layers, n_anchors=32, manifold_dim=256,
1338
+ n_comp=8, d_comp=32, context_dim=128, quat_dim=64, **kw)
1339
+
1340
+ def geo_transformer_small(name='geo_small', n_layers=4, **kw):
1341
+ """Small config for prototyping."""
1342
+ return GeometricTransformer(name, d_model=256, n_heads=8,
1343
+ n_layers=n_layers, n_anchors=16, manifold_dim=128,
1344
+ n_comp=4, d_comp=16, context_dim=64, quat_dim=32, **kw)
1345
+
1346
+ def geo_transformer_vision(name='geo_vit', n_layers=4, **kw):
1347
+ """For scatter/SVD vision pipeline (patches as tokens)."""
1348
+ return GeometricTransformer(name, d_model=384, n_heads=8,
1349
+ n_layers=n_layers, n_anchors=32, manifold_dim=128,
1350
+ n_comp=8, d_comp=16, context_dim=64, quat_dim=32, **kw)
1351
+
1352
+
1353
+ # ═══════════════════════════════════════════════════════════════════════════════
1354
+ # SELF-TEST
1355
+ # ═══════════════════════════════════════════════════════════════════════════════
1356
+
1357
+ if __name__ == '__main__':
1358
+ print("Geometric Transformer β€” CM Validated β€” Self-Test")
1359
+ print(f" geolip_core available: {_HAS_GEOLIP}")
1360
+ print("=" * 60)
1361
+
1362
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1363
+
1364
+ # ── Build small model ──
1365
+ model = geo_transformer_small('test_cm', n_layers=2)
1366
+ if hasattr(model, 'network_to'):
1367
+ model.network_to(device=device, strict=False)
1368
+ else:
1369
+ model = model.to(device)
1370
+ total = model.param_report()
1371
+
1372
+ # ── Forward pass ──
1373
+ B, L, D = 2, 32, 256
1374
+ x = torch.randn(B, L, D, device=device)
1375
+ out, geos = model(x, return_geo_state=True)
1376
+
1377
+ assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}"
1378
+ assert len(geos) == 2
1379
+ print(f"\n Input: ({B}, {L}, {D})")
1380
+ print(f" Output: {out.shape}")
1381
+ print(f" Geo states: {len(geos)} layers")
1382
+
1383
+ # ── Verify CM gate is active ──
1384
+ for i, gs in enumerate(geos):
1385
+ gi = gs['gate_info']
1386
+ cm_q = gs['cm_quality']
1387
+ gv = gs['gate_values']
1388
+ print(f"\n Layer {i} CM gate:")
1389
+ print(f" active anchors: {gi['active'].item():.1f} / {model.n_anchors}")
1390
+ print(f" gate mean: {gi['gate_mean'].item():.4f}")
1391
+ print(f" cm_positive_frac: {gi['cm_positive_frac'].item():.3f}")
1392
+ print(f" gate_values: {gv.shape} range=[{gv.min():.3f}, {gv.max():.3f}]")
1393
+ print(f" cm_quality: {cm_q.shape} mean={cm_q.mean():.4f}")
1394
+
1395
+ # ── Verify geo_residual continuity ──
1396
+ gr0 = geos[0]['geo_residual']
1397
+ gr1 = geos[1]['geo_residual']
1398
+ print(f"\n Geo residual stream:")
1399
+ print(f" Layer 0: {gr0.shape} norm={gr0.norm(dim=-1).mean():.4f}")
1400
+ print(f" Layer 1: {gr1.shape} norm={gr1.norm(dim=-1).mean():.4f}")
1401
+
1402
+ # ── Geometric losses ──
1403
+ geo_losses = model.geometric_losses()
1404
+ print(f"\n Geometric regularization:")
1405
+ for k, v in geo_losses.items():
1406
+ print(f" {k}: {v.item():.6f}")
1407
+
1408
+ # ── Anchor diagnostics ──
1409
+ diag = model.anchor_diagnostics()
1410
+ print(f"\n Anchor diagnostics:")
1411
+ for layer_name, d in diag.items():
1412
+ print(f" {layer_name}:")
1413
+ for k, v in d.items():
1414
+ print(f" {k}: {v:.4f}")
1415
+
1416
+ # ── Verify Cayley rotations ──
1417
+ print(f"\n Cayley rotations:")
1418
+ for name, module in model.named_modules():
1419
+ if isinstance(module, CayleyOrthogonal):
1420
+ R = module.get_rotation()
1421
+ I = torch.eye(R.shape[0], device=R.device)
1422
+ print(f" {name}: β€–RRα΅€-Iβ€–={((R@R.T)-I).norm():.8f} det={torch.det(R):.4f}")
1423
+
1424
+ # ── Gradient flow through CM gate ──
1425
+ print(f"\n Gradient flow test:")
1426
+ model.zero_grad()
1427
+ x_grad = torch.randn(B, L, D, device=device, requires_grad=True)
1428
+ out_grad = model(x_grad)
1429
+ loss = out_grad.sum()
1430
+ loss.backward()
1431
+
1432
+ # Check gate_proj has gradients
1433
+ for i in range(model.n_layers):
1434
+ layer = model[f'layer_{i}']
1435
+ gate_grads = [p.grad is not None and p.grad.abs().sum() > 0
1436
+ for p in layer['cm_gate'].parameters()]
1437
+ print(f" layer_{i} cm_gate grad: {'YES' if all(gate_grads) else 'NO'}")
1438
+
1439
+ # ── Training step simulation ──
1440
+ print(f"\n Training step simulation:")
1441
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
1442
+ optimizer.zero_grad()
1443
+
1444
+ x_train = torch.randn(B, L, D, device=device)
1445
+ out_train, states = model(x_train, return_geo_state=True)
1446
+ task_loss = out_train.mean() # dummy
1447
+
1448
+ geo_losses = model.geometric_losses()
1449
+ total_loss = task_loss + geo_losses.get('geo_total', 0.0)
1450
+ total_loss.backward()
1451
+ optimizer.step()
1452
+ print(f" task_loss: {task_loss.item():.4f}")
1453
+ print(f" cv_loss: {geo_losses['cv'].item():.6f}")
1454
+ print(f" spread_loss:{geo_losses['spread'].item():.6f}")
1455
+ print(f" total: {total_loss.item():.4f}")
1456
+
1457
+ # ── Paired forward + observer loss (if geolip_core available) ──
1458
+ if _HAS_GEOLIP:
1459
+ print(f"\n Paired forward + observer loss:")
1460
+ model.zero_grad()
1461
+
1462
+ x1 = torch.randn(B, L, D, device=device)
1463
+ x2 = x1 + 0.1 * torch.randn_like(x1) # view 2 = slight perturbation
1464
+ targets = torch.randint(0, 10, (B,), device=device)
1465
+
1466
+ output = model.forward_paired(x1, x2)
1467
+ print(f" Output keys: {sorted(k for k in output if not k.startswith('geo_'))}")
1468
+ for k in ['embedding', 'patchwork1', 'bridge1', 'assign1', 'tri1']:
1469
+ print(f" {k}: {output[k].shape}")
1470
+
1471
+ # Task head for CE
1472
+ num_classes = 10
1473
+ head = nn.Linear(D, num_classes).to(device)
1474
+
1475
+ loss, ld = model.compute_loss(output, targets, head=head)
1476
+ print(f"\n Three-domain loss breakdown:")
1477
+ for k in ['loss_observer', 'loss_task', 'ce', 'nce_emb', 'nce_pw',
1478
+ 'bridge', 'assign', 'assign_nce', 'nce_tri', 'attract',
1479
+ 'cv', 'spread']:
1480
+ if k in ld:
1481
+ v = ld[k]
1482
+ v = v.item() if isinstance(v, torch.Tensor) else v
1483
+ print(f" {k:16s} = {v:.4f}")
1484
+ for k in ['nce_emb_acc', 'nce_pw_acc', 'nce_tri_acc', 'bridge_acc',
1485
+ 'assign_nce_acc', 'acc']:
1486
+ if k in ld:
1487
+ v = ld[k]
1488
+ v = v if isinstance(v, float) else v.item()
1489
+ print(f" {k:16s} = {v*100:.1f}%")
1490
+ print(f" {'TOTAL':16s} = {loss.item():.4f}")
1491
+
1492
+ # Verify backward through observer loss
1493
+ loss.backward()
1494
+ alive, dead = 0, 0
1495
+ for n, p in model.named_parameters():
1496
+ if p.grad is not None and p.grad.norm() > 0:
1497
+ alive += 1
1498
+ else:
1499
+ dead += 1
1500
+ print(f"\n Gradient flow: {alive} params alive, {dead} dead")
1501
+
1502
+ # Check critical components
1503
+ for i in range(model.n_layers):
1504
+ layer = model[f'layer_{i}']
1505
+ for comp_name in ['cm_gate', 'observer']:
1506
+ has = any(p.grad is not None and p.grad.norm() > 0
1507
+ for p in layer[comp_name].parameters())
1508
+ print(f" layer_{i}.{comp_name}: {'LIVE' if has else 'DEAD'}")
1509
+
1510
+ # Bridge specifically β€” was never used in loss before
1511
+ for i in range(model.n_layers):
1512
+ layer = model[f'layer_{i}']
1513
+ bridge = layer['observer'].curation.bridge
1514
+ has = any(p.grad is not None and p.grad.norm() > 0
1515
+ for p in bridge.parameters())
1516
+ print(f" layer_{i}.bridge: {'LIVE' if has else 'DEAD'}")
1517
+ else:
1518
+ print(f"\n [SKIP] forward_paired + compute_loss require geolip_core imports")
1519
+
1520
+ print(f"\n{'='*60}")
1521
+ print(f" PASSED β€” CM-validated pipeline operational")
1522
+ print(f"{'='*60}")