AbstractPhil committed (verified)
Commit 0577dce · Parent(s): d8fc08a

Create model.py

Files changed (1): model.py (+940, -0)

model.py ADDED
@@ -0,0 +1,940 @@
"""
Geometric Transformer — GeoLIP Pipeline Integration
=====================================================
Dual-stream transformer with constellation-routed attention,
quaternion composition, and per-layer Cayley alignment.

Uses REAL geolip_core components:
    core.associate.constellation — ConstellationObserver (anchors + triangulation + patchwork)
    core.curate.gate             — AnchorGate (CM determinant validity)
    core.align.procrustes        — CayleyOrthogonal rotation in SO(d)
    pipeline.observer            — TorchComponent / BaseTower interfaces

NEW components (transformer-specific):
    ManifoldProjection        — Input stage: hidden_state → S^(d-1)
    PositionGeometricContext  — Curation: constellation output → FiLM context
    FiLMLayer                 — Feature-wise Linear Modulation (proven in Ryan Spearman)
    GeometricAttention        — Attention with FiLM on Q,K from curated constellation
    QuaternionCompose         — Hamilton product of dual-stream outputs (proven)
    CayleyOrthogonal          — SO(d) rotation via Cayley map (proven)
    GeometricTransformerLayer — Full layer: project → observe → attend → compose
                                (content + geometric streams, aligned + composed)
    GeometricTransformer      — Stack of layers with cross-layer rotation

Architecture per layer:
    1. ManifoldProjection: h_i → emb_i on S^(d-1) per position
    2. ConstellationObserver: emb_i → {triangulation, assignment, patchwork, bridge}
    3. PositionGeometricContext: constellation output → (B, L, context_dim)
    4. Stream A (content): standard self-attention
    5. Stream B (geometric): attention with FiLM(Q,K | geo_ctx), V unmodulated
    6. CayleyOrthogonal: align B → A basis
    7. QuaternionCompose: w=content, i=aligned_geo, j=disagree, k=agree
    8. Gated residual

Design principles from Ryan Spearman (ρ=0.309, 76/84 wins):
    - FiLM on Q,K ONLY — geometry routes attention, V stays pure
    - FiLM on individual arms BEFORE composition, not after
    - Quaternion algebra as structural regularizer (non-commutative coupling)
    - Disagreement arm (j) carries the transferable signal
    - CayleyOrthogonal guarantees pure rotation (det=1 always)
    - Never global average pool — per-position geometric context

Usage:
    from geometric_transformer import GeometricTransformer

    model = GeometricTransformer('geo_xfmr', d_model=512, n_layers=4)
    out = model(hidden_states)

    # Or as a head on frozen ESM-2:
    model = GeometricTransformer('esm2_geo', d_model=1280, n_layers=6)
    out = model(esm2_hidden_states)

Dependencies:
    pip install geolip-core  (includes constellation, patchwork, gate, observer interfaces)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ═══════════════════════════════════════════════════════════════════════════════
# GEOLIP IMPORTS — real components, not reimplementations
# ═══════════════════════════════════════════════════════════════════════════════

try:
    from geolip_core.core.associate.constellation import (
        ConstellationObserver, ConstellationAssociation, ConstellationCuration,
        Constellation, init_anchors_repulsion,
    )
    from geolip_core.core.curate.gate import AnchorGate
    from geolip_core.pipeline.observer import (
        TorchComponent, BaseTower, Input, Curation, Distinction,
    )
    _HAS_GEOLIP = True
except ImportError:
    _HAS_GEOLIP = False

    # ── Fallback stubs ──
    class TorchComponent(nn.Module):
        def __init__(self, name=None, **kwargs):
            super().__init__()
            self._component_name = name or self.__class__.__name__

    class BaseTower(nn.Module):
        def __init__(self, name=None, **kwargs):
            super().__init__()
            self._tower_name = name or self.__class__.__name__
            self._components = nn.ModuleDict()
            self._cache = {}

        def attach(self, name, module):
            if isinstance(module, nn.Module):
                self._components[name] = module
            return self

        def has(self, name):
            return name in self._components

        def __getitem__(self, key):
            return self._components[key]

        def cache_set(self, key, value):
            self._cache[key] = value

        def cache_get(self, key, default=None):
            return self._cache.get(key, default)

        def cache_clear(self):
            self._cache.clear()

    Input = TorchComponent
    Curation = TorchComponent
    Distinction = TorchComponent

    class Constellation(nn.Module):
        """Learned anchors on S^(d-1). Triangulates input embeddings."""
        def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
            super().__init__()
            self.n_anchors = n_anchors
            self.dim = dim
            self.anchor_drop = anchor_drop
            anchors = torch.randn(n_anchors, dim)
            # Repulsion-initialized
            anchors = F.normalize(anchors, dim=-1)
            for _ in range(200):
                sim = anchors @ anchors.T
                sim.fill_diagonal_(-2.0)
                anchors = F.normalize(anchors - 0.05 * anchors[sim.argmax(dim=1)], dim=-1)
            self.anchors = nn.Parameter(anchors)

        def triangulate(self, emb, training=False):
            anchors = F.normalize(self.anchors, dim=-1)
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
            return tri, nearest

        def forward(self, emb, training=False):
            return self.triangulate(emb, training)

    class ConstellationAssociation(TorchComponent):
        """Association through constellation anchors."""
        def __init__(self, dim=256, n_anchors=32, anchor_drop=0.0,
                     anchor_init='repulsion', assign_temp=0.1, **kwargs):
            super().__init__(**kwargs)
            self.assign_temp = assign_temp
            self.constellation = Constellation(n_anchors, dim, anchor_drop, anchor_init)

        @property
        def frame_dim(self):
            return self.constellation.n_anchors

        def associate(self, emb, **context):
            anchors_n = F.normalize(self.constellation.anchors, dim=-1)
            cos = emb @ anchors_n.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
            soft_assign = F.softmax(cos / self.assign_temp, dim=-1)
            mag = context.get('mag', None)
            distances_weighted = tri * mag if mag is not None else tri
            return {
                'distances': tri, 'distances_weighted': distances_weighted,
                'cos_to_anchors': cos, 'assignment': soft_assign,
                'nearest': nearest,
            }

        def forward(self, emb, **context):
            return self.associate(emb, **context)

    class Patchwork(nn.Module):
        """Round-robin patchwork compartments."""
        def __init__(self, n_anchors, n_comp=8, d_comp=32, activation='gelu'):
            super().__init__()
            self.n_comp = n_comp
            anchors_per = max(1, n_anchors // n_comp)
            self.compartments = nn.ModuleList([
                nn.Sequential(nn.Linear(anchors_per, d_comp), nn.GELU(), nn.Linear(d_comp, d_comp))
                for _ in range(n_comp)
            ])
            self.output_dim = n_comp * d_comp
            self.anchors_per = anchors_per

        def forward(self, distances):
            parts = []
            for i, comp in enumerate(self.compartments):
                start = i * self.anchors_per
                end = start + self.anchors_per
                chunk = distances[..., start:end]
                if chunk.shape[-1] < self.anchors_per:
                    chunk = F.pad(chunk, (0, self.anchors_per - chunk.shape[-1]))
                parts.append(comp(chunk))
            return torch.cat(parts, dim=-1)

    class ConstellationCuration(Curation):
        """Curation through patchwork compartments + bridge."""
        def __init__(self, n_anchors=32, dim=256, n_comp=8, d_comp=32,
                     activation='gelu', **kwargs):
            super().__init__(**kwargs)
            self.dim = dim
            self.n_anchors = n_anchors
            self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
            pw_dim = self.patchwork.output_dim
            self.bridge = nn.Linear(pw_dim, n_anchors)
            self._feature_dim = n_anchors + pw_dim + dim

        @property
        def feature_dim(self):
            return self._feature_dim

        def curate_full(self, association_output, emb=None, **context):
            distances = association_output['distances_weighted']
            assignment = association_output['assignment']
            pw = self.patchwork(distances)
            bridge = self.bridge(pw)
            parts = [assignment, pw]
            if emb is not None:
                parts.append(emb)
            features = torch.cat(parts, dim=-1)
            return {'patchwork': pw, 'bridge': bridge, 'features': features}

        def forward(self, association_output, emb=None, **context):
            return self.curate_full(association_output, emb=emb, **context)['features']

    class ConstellationObserver(nn.Module):
        """Composed association + curation."""
        def __init__(self, dim=256, n_anchors=32, n_comp=8, d_comp=32,
                     anchor_drop=0.0, anchor_init='repulsion',
                     activation='gelu', assign_temp=0.1):
            super().__init__()
            self.association = ConstellationAssociation(
                dim=dim, n_anchors=n_anchors, anchor_drop=anchor_drop,
                anchor_init=anchor_init, assign_temp=assign_temp)
            self.curation = ConstellationCuration(
                n_anchors=n_anchors, dim=dim, n_comp=n_comp,
                d_comp=d_comp, activation=activation)

        @property
        def constellation(self):
            return self.association.constellation

        @property
        def patchwork(self):
            return self.curation.patchwork

        @property
        def feature_dim(self):
            return self.curation.feature_dim

        def observe(self, emb, **context):
            a_out = self.association(emb, **context)
            c_out = self.curation.curate_full(a_out, emb=emb, **context)
            return {
                'embedding': emb, 'features': c_out['features'],
                'triangulation': a_out['distances'],
                'cos_to_anchors': a_out['cos_to_anchors'],
                'nearest': a_out['nearest'],
                'assignment': a_out['assignment'],
                'patchwork': c_out['patchwork'], 'bridge': c_out['bridge'],
            }

        def forward(self, emb, **context):
            return self.observe(emb, **context)

# ═══════════════════════════════════════════════════════════════════════════════
# PROVEN COMPONENTS — from Ryan Spearman (unchanged, tested)
# ═══════════════════════════════════════════════════════════════════════════════

class FiLMLayer(TorchComponent):
    """Feature-wise Linear Modulation. Proven in Ryan Spearman.

    Produces γ * x + β from geometric context.
    Identity-initialized: γ=1, β=0 at init.
    """
    def __init__(self, name, feature_dim, context_dim):
        super().__init__(name)
        self.to_gamma = nn.Linear(context_dim, feature_dim)
        self.to_beta = nn.Linear(context_dim, feature_dim)
        nn.init.zeros_(self.to_gamma.weight); nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.weight); nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, ctx):
        """x: (B, L, D), ctx: (B, L, C) → (B, L, D)"""
        return self.to_gamma(ctx) * x + self.to_beta(ctx)

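# Illustrative sketch, not part of the pipeline: checks the identity-init
# property the docstring claims. With γ=1 and β=0, FiLM starts as a no-op,
# so attaching it to a pretrained stack changes nothing until the context
# weights move. `_film_identity_check` is a hypothetical helper name, not a
# geolip_core API.
def _film_identity_check():
    film = FiLMLayer('film_demo', feature_dim=8, context_dim=4)
    x, ctx = torch.randn(2, 3, 8), torch.randn(2, 3, 4)
    assert torch.allclose(film(x, ctx), x, atol=1e-6)
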
class CayleyOrthogonal(TorchComponent):
    """Guaranteed SO(d) rotation via Cayley map. Proven in Procrustes alignment.

    Q = (I - A)(I + A)^(-1) where A is skew-symmetric.
    det(Q) = 1 always. ‖R-I‖ ≈ 4.1 at convergence in SO(256).

    Caches the rotation matrix — only recomputes when A_upper changes
    (i.e. after optimizer.step()). The solve is input-independent.
    """
    def __init__(self, name, dim):
        super().__init__(name)
        self.dim = dim
        # A = 0 at init, so the Cayley map starts at R = I (identity rotation).
        self.A_upper = nn.Parameter(torch.zeros(dim * (dim - 1) // 2))
        self._cached_R = None
        self._cached_A_version = None

    def _param_version(self):
        """Track parameter changes via data_ptr + in-place version counter."""
        return self.A_upper.data_ptr(), self.A_upper._version

    def get_rotation(self):
        # During training: always recompute (autograd needs a fresh graph).
        # During eval: cache the rotation (params don't change).
        if self.training:
            self._cached_R = None

        version = self._param_version()
        if self._cached_R is not None and self._cached_A_version == version:
            return self._cached_R

        d = self.dim
        A = torch.zeros(d, d, device=self.A_upper.device, dtype=self.A_upper.dtype)
        idx = torch.triu_indices(d, d, offset=1, device=A.device)
        A[idx[0], idx[1]] = self.A_upper
        A = A - A.T
        I = torch.eye(d, device=A.device, dtype=A.dtype)
        # solve gives (I + A)^(-1)(I - A), which equals (I - A)(I + A)^(-1)
        # because both factors are rational functions of A and commute.
        R = torch.linalg.solve(I + A, I - A)

        if not self.training:
            self._cached_R = R
            self._cached_A_version = version
        return R

    def invalidate_cache(self):
        """Call after optimizer.step() if needed."""
        self._cached_R = None
        self._cached_A_version = None

    def forward(self, x):
        """(..., dim) → (..., dim) rotated."""
        return x @ self.get_rotation().T

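# Illustrative sketch (hypothetical helper, never called by the pipeline):
# verifies the Cayley-map guarantee stated above on a small SO(8) instance.
# Any R from get_rotation() should satisfy R Rᵀ = I with det(R) = +1, even
# for a randomly perturbed A_upper.
def _cayley_rotation_check():
    rot = CayleyOrthogonal('cayley_demo', dim=8)
    with torch.no_grad():
        rot.A_upper.normal_(std=0.1)  # move off the identity init
    R = rot.get_rotation()
    assert torch.allclose(R @ R.T, torch.eye(8), atol=1e-5)
    assert torch.allclose(torch.det(R), torch.tensor(1.0), atol=1e-5)
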
def quaternion_multiply(q1, q2):
    """Hamilton product. q = (w, x, y, z) along dim=-2.

    Supports batched: (..., 4, D) × (..., 4, D) → (..., 4, D)
    Or scalar: (..., 4) × (..., 4) → (..., 4)
    """
    w1, x1, y1, z1 = q1.unbind(-2) if q1.dim() >= 2 and q1.shape[-2] == 4 else q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-2) if q2.dim() >= 2 and q2.shape[-2] == 4 else q2.unbind(-1)
    stack_dim = -2 if q1.dim() >= 2 and q1.shape[-2] == 4 else -1
    return torch.stack([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ], dim=stack_dim)


def quaternion_multiply_batched(q1, q2):
    """Hamilton product on (B, 4, D) tensors. Fully vectorized, no loops.

    Each of the 4 slices along dim=1 is one quaternion component.
    The D dimension is batched — all D quaternions multiplied in parallel.
    """
    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
    return torch.stack([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ], dim=1)  # (B, 4, D)

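# Illustrative sketch (hypothetical helper): the identity quaternion
# (1, 0, 0, 0), broadcast over the D lanes, is a left identity for the
# batched Hamilton product. A quick way to sanity-check the sign pattern
# in the product table above.
def _hamilton_identity_check():
    e = torch.zeros(2, 4, 8)
    e[:, 0] = 1.0                     # w=1, x=y=z=0 in every lane
    q = torch.randn(2, 4, 8)
    assert torch.allclose(quaternion_multiply_batched(e, q), q, atol=1e-6)
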
class QuaternionCompose(TorchComponent):
    """Four-arm Hamilton product composition. Proven in GeoQuat head.

    The algebra forces cross-term interactions between arms.
    Arms cannot independently memorize — the non-commutative
    product couples their outputs as a structural regularizer.

    Fully vectorized: single batched Hamilton product, no Python loops.
    """
    def __init__(self, name, input_dim, quat_dim=64):
        super().__init__(name)
        self.quat_dim = quat_dim
        self.proj_w = nn.Linear(input_dim, quat_dim)
        self.proj_i = nn.Linear(input_dim, quat_dim)
        self.proj_j = nn.Linear(input_dim, quat_dim)
        self.proj_k = nn.Linear(input_dim, quat_dim)
        self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1)

    @property
    def output_dim(self):
        return self.quat_dim * 4

    def forward(self, arm_w, arm_i, arm_j, arm_k):
        """Each arm: (B, L, D) → composed: (B, L, 4*quat_dim)"""
        shape = arm_w.shape[:-1]
        D = arm_w.shape[-1]
        flat = arm_w.dim() > 2
        if flat:
            arm_w = arm_w.reshape(-1, D); arm_i = arm_i.reshape(-1, D)
            arm_j = arm_j.reshape(-1, D); arm_k = arm_k.reshape(-1, D)

        # q: (N, 4, quat_dim) — stack 4 projected arms as quaternion components
        q = torch.stack([self.proj_w(arm_w), self.proj_i(arm_i),
                         self.proj_j(arm_j), self.proj_k(arm_k)], dim=1)
        q = q / (q.norm(dim=1, keepdim=True) + 1e-8)

        # r: (N, 4, quat_dim) — broadcast learned rotation
        r = self.rotation.expand(q.shape[0], -1, -1)
        r = r / (r.norm(dim=1, keepdim=True) + 1e-8)

        # Single batched Hamilton product over all quat_dim simultaneously
        # (N, 4, quat_dim) × (N, 4, quat_dim) → (N, 4, quat_dim)
        composed = quaternion_multiply_batched(r, q)

        # Flatten 4 × quat_dim → 4*quat_dim
        composed = composed.reshape(q.shape[0], -1)

        if flat:
            composed = composed.reshape(*shape, -1)
        return composed

# ═══════════════════════════════════════════════════════════════════════════════
# NEW COMPONENTS — transformer-specific, built for this architecture
# ═══════════════════════════════════════════════════════════════════════════════

class ManifoldProjection(TorchComponent):
    """Input stage: project transformer hidden states to S^(d-1).

    Per-position, per-layer projection from model space to the
    constellation's embedding space. L2-normalized to sit on the
    unit hypersphere.

    This is the tap — it reads the representation without modifying it.
    """
    def __init__(self, name, d_model, manifold_dim):
        super().__init__(name)
        self.proj = nn.Linear(d_model, manifold_dim)
        self.norm = nn.LayerNorm(manifold_dim)

    def forward(self, hidden_states):
        """(B, L, D) → (B, L, manifold_dim) on S^(manifold_dim - 1)"""
        h = self.norm(self.proj(hidden_states))
        return F.normalize(h, dim=-1)

class PositionGeometricContext(TorchComponent):
    """Curation stage: constellation observation → FiLM context vector.

    Takes the full observation dict from ConstellationObserver and fuses
    it into a per-position conditioning vector for FiLM layers.

    Processes: cos_to_anchors, assignment, patchwork, embedding.
    These are the same features the GeoQuat head used — validated on
    ProteinGym across 84 unseen proteins.
    """
    def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim):
        super().__init__(name)
        # Anchor features: cos + assignment + triangulation = 3 * n_anchors
        self.anchor_mlp = nn.Sequential(
            nn.Linear(n_anchors * 3, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )
        # Structural features: patchwork + embedding
        self.struct_mlp = nn.Sequential(
            nn.Linear(pw_dim + manifold_dim, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )
        # Fuse anchor + structural
        self.fuse = nn.Sequential(
            nn.Linear(context_dim * 2, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )

    def forward(self, obs_dict):
        """
        Args:
            obs_dict: from ConstellationObserver.observe(), keys:
                cos_to_anchors: (B*L, A)
                assignment: (B*L, A)
                triangulation: (B*L, A)
                patchwork: (B*L, pw_dim)
                embedding: (B*L, manifold_dim)
        Returns:
            (B*L, context_dim) geometric context
        """
        anchor_feats = torch.cat([
            obs_dict['cos_to_anchors'],
            obs_dict['assignment'],
            obs_dict['triangulation'],
        ], dim=-1)

        struct_feats = torch.cat([
            obs_dict['patchwork'],
            obs_dict['embedding'],
        ], dim=-1)

        a = self.anchor_mlp(anchor_feats)
        s = self.struct_mlp(struct_feats)
        return self.fuse(torch.cat([a, s], dim=-1))

class GeometricAttention(TorchComponent):
    """Attention with FiLM from curated constellation. Stream B.

    FiLM modulates Q and K BEFORE attention — the constellation
    position controls WHERE attention flows. V stays unmodulated.
    FiLM between the FFN layers conditions the nonlinearity.

    Proven principle: context before composition, not after.
    """
    def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1):
        super().__init__(name)
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # FiLM on Q and K — geometry routes attention
        self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim)
        self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim)

        self.norm = nn.LayerNorm(d_model)

        # FFN with FiLM between layers
        self.ffn1 = nn.Linear(d_model, d_model * 4)
        self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim)
        self.ffn2 = nn.Linear(d_model * 4, d_model)
        self.ffn_drop = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None):
        """x: (B, L, D), geo_ctx: (B, L, C) → (B, L, D)"""
        B, L, D = x.shape
        H, HD = self.n_heads, self.head_dim

        Q = self.film_q(self.w_q(x), geo_ctx)
        K = self.film_k(self.w_k(x), geo_ctx)
        V = self.w_v(x)  # V unmodulated — content stays pure

        Q = Q.view(B, L, H, HD).transpose(1, 2)
        K = K.view(B, L, H, HD).transpose(1, 2)
        V = V.view(B, L, H, HD).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            scores = scores + attn_mask
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn_out = self.dropout(F.softmax(scores, dim=-1)) @ V
        attn_out = attn_out.transpose(1, 2).reshape(B, L, D)

        x = self.norm(x + self.w_o(attn_out))

        # FFN with geometric FiLM between layers
        h = F.gelu(self.ffn1(x))
        h = self.film_ffn(h, geo_ctx)
        x = self.ffn_norm(x + self.ffn_drop(self.ffn2(h)))

        return x

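# Illustrative sketch (hypothetical helper): the shape contract of Stream B.
# A (B, L, D) input plus a (B, L, C) geometric context comes back as
# (B, L, D); the context enters only through the Q/K and FFN FiLM layers.
def _geometric_attention_shape_check():
    attn = GeometricAttention('geo_attn_demo', d_model=64, n_heads=4, context_dim=16)
    x, ctx = torch.randn(2, 8, 64), torch.randn(2, 8, 16)
    assert attn(x, ctx).shape == (2, 8, 64)
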
class ContentAttention(TorchComponent):
    """Standard self-attention. Stream A. No geometric conditioning."""
    def __init__(self, name, d_model, n_heads=8, dropout=0.1):
        super().__init__(name)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(),
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        a, _ = self.attn(x, x, x, attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask)
        x = self.norm(x + a)
        x = self.ffn_norm(x + self.ffn(x))
        return x

# ═══════════════════════════════════════════════════════════════════════════════
# LAYER — dual-stream with constellation routing
# ═══════════════════════════════════════════════════════════════════════════════

class GeometricTransformerLayer(BaseTower):
    """One layer of the geometric transformer.

    Pipeline per layer:
        1. ManifoldProjection: h_i → emb_i on S^(manifold_dim - 1)
        2. ConstellationObserver: emb_i → {triangulation, assignment, patchwork, ...}
        3. PositionGeometricContext: observation → FiLM context (B, L, context_dim)
        4. ContentAttention (Stream A): standard MHA
        5. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure
        6. CayleyOrthogonal: align B basis → A basis
        7. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
        8. Decode + gated residual

    Access:
        layer['projection'] → ManifoldProjection
        layer['observer']   → ConstellationObserver
        layer['context']    → PositionGeometricContext
        layer['content']    → ContentAttention
        layer['geometric']  → GeometricAttention
        layer['rotation']   → CayleyOrthogonal
        layer['compose']    → QuaternionCompose
    """
    def __init__(self, name, d_model, n_heads=8, n_anchors=32,
                 manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1):
        super().__init__(name)
        self.d_model = d_model

        # 1. Project to manifold
        self.attach('projection', ManifoldProjection(
            f'{name}_proj', d_model, manifold_dim))

        # 2. Constellation observer (real association + curation)
        self.attach('observer', ConstellationObserver(
            dim=manifold_dim, n_anchors=n_anchors,
            n_comp=n_comp, d_comp=d_comp))

        # 3. Fuse observation into FiLM context
        pw_dim = self['observer'].curation.patchwork.output_dim
        self.attach('context', PositionGeometricContext(
            f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim))

        # 4. Stream A: content
        self.attach('content', ContentAttention(
            f'{name}_content', d_model, n_heads, dropout))

        # 5. Stream B: geometric
        self.attach('geometric', GeometricAttention(
            f'{name}_geo', d_model, n_heads, context_dim, dropout))

        # 6. Cayley rotation: align B → A
        self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model))

        # 7. Quaternion composition
        self.attach('compose', QuaternionCompose(
            f'{name}_quat', d_model, quat_dim))

        # 8. Decode + gate
        self.attach('decode', nn.Sequential(
            nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model)))
        self.attach('gate', nn.Sequential(
            nn.Linear(d_model * 2, d_model), nn.Sigmoid()))

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """
        Args:
            x: (B, L, D) input hidden states

        Returns:
            x_out: (B, L, D) transformed hidden states
            geo_state: dict with the full geometric residual:
                'embedding':      (B, L, manifold_dim) position on S^(d-1)
                'geo_ctx':        (B, L, context_dim) compressed FiLM context
                'triangulation':  (B, L, A) cosine distances to anchors
                'cos_to_anchors': (B, L, A) raw cosine similarities
                'assignment':     (B, L, A) soft assignment
                'nearest':        (B, L) nearest anchor index
                'patchwork':      (B, L, pw_dim) compartment features
                'bridge':         (B, L, A) patchwork's assignment estimate
                'content':        (B, L, D) Stream A output
                'geometric':      (B, L, D) Stream B output (pre-rotation)
                'composed':       (B, L, 4*quat_dim) raw quaternion composition
        """
        B, L, D = x.shape

        # 1. Project to manifold: per-position embedding on S^(d-1)
        emb = self['projection'](x)  # (B, L, manifold_dim)

        # 2. Constellation observation: flatten to (B*L, manifold_dim) for observer
        emb_flat = emb.reshape(B * L, -1)
        obs = self['observer'].observe(emb_flat)

        # 3. Build FiLM context
        geo_ctx_flat = self['context'](obs)       # (B*L, context_dim)
        geo_ctx = geo_ctx_flat.reshape(B, L, -1)  # (B, L, context_dim)

        # 4. Stream A: content attention
        a_out = self['content'](x, attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask)

        # 5. Stream B: geometric attention
        b_out = self['geometric'](x, geo_ctx, attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask)

        # 6. Cayley rotation: align B → A
        b_aligned = self['rotation'](b_out)

        # 7. Quaternion composition
        #    w = content (what does standard attention think?)
        #    i = aligned geometry (what does geometric attention think?)
        #    j = disagreement (where do they diverge? — the surprise signal)
        #    k = agreement (where do they converge? — the confidence signal)
        composed = self['compose'](
            arm_w=a_out, arm_i=b_aligned,
            arm_j=a_out - b_aligned, arm_k=a_out * b_aligned)

        # 8. Decode + gated residual
        decoded = self['decode'](composed)
        g = self['gate'](torch.cat([x, decoded], dim=-1))
        x_out = g * decoded + (1 - g) * x

        # 9. Build full geometric state — reshape everything back to (B, L, ...)
        def unflatten(t):
            if t is None:
                return None
            if t.dim() == 1:
                return t.reshape(B, L)            # (B*L,) → (B, L)
            return t.reshape(B, L, *t.shape[1:])  # (B*L, ...) → (B, L, ...)

        geo_state = {
            'embedding': emb,      # already (B, L, manifold_dim)
            'geo_ctx': geo_ctx,    # already (B, L, context_dim)
            'triangulation': unflatten(obs['triangulation']),
            'cos_to_anchors': unflatten(obs['cos_to_anchors']),
            'assignment': unflatten(obs['assignment']),
            'nearest': unflatten(obs['nearest']),
            'patchwork': unflatten(obs['patchwork']),
            'bridge': unflatten(obs['bridge']),
            'content': a_out,      # (B, L, D)
            'geometric': b_out,    # (B, L, D) pre-rotation
            'composed': composed,  # (B, L, 4*quat_dim)
        }

        return x_out, geo_state

# ═══════════════════════════════════════════════════════════════════════════════
# FULL MODEL — stack of layers
# ═══════════════════════════════════════════════════════════════════════════════

class GeometricTransformer(BaseTower):
    """Geometric Transformer — dual-stream with constellation routing.

    Stack of GeometricTransformerLayers. Optional cross-layer Cayley
    rotation aligns each layer's output basis to the next layer's
    expected input.

    Access:
        model['layer_0']     → first layer
        model['cross_rot_0'] → cross-layer rotation 0→1
        model['final_norm']  → output normalization

    Args:
        name: tower identity
        d_model: transformer model dimension
        n_heads: attention heads per stream
        n_layers: number of geometric transformer layers
        n_anchors: constellation anchor points
        manifold_dim: dimension of S^(d-1) for constellation
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        context_dim: FiLM conditioning dimension
        quat_dim: quaternion space dimension
        dropout: dropout rate
        cross_layer_rotation: add Cayley rotation between layers
        vocab_size: if set, adds embedding + output head
    """
    def __init__(self, name, d_model=512, n_heads=8, n_layers=4,
                 n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1,
                 cross_layer_rotation=True, vocab_size=None, max_seq_len=2048):
        super().__init__(name)
        self.d_model = d_model
        self.n_layers = n_layers

        if vocab_size is not None:
            self.attach('embed', nn.Embedding(vocab_size, d_model))
            self.attach('pos_embed', nn.Embedding(max_seq_len, d_model))
            self.attach('head', nn.Linear(d_model, vocab_size, bias=False))

        for i in range(n_layers):
            self.attach(f'layer_{i}', GeometricTransformerLayer(
                f'{name}_L{i}', d_model, n_heads, n_anchors,
                manifold_dim, n_comp, d_comp, context_dim, quat_dim, dropout))

        if cross_layer_rotation and n_layers > 1:
            for i in range(n_layers - 1):
                self.attach(f'cross_rot_{i}', CayleyOrthogonal(
                    f'{name}_xrot_{i}', d_model))

        self.attach('final_norm', nn.LayerNorm(d_model))

        self._config = dict(
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            n_anchors=n_anchors, manifold_dim=manifold_dim,
            n_comp=n_comp, d_comp=d_comp, context_dim=context_dim,
            quat_dim=quat_dim, dropout=dropout,
            cross_layer_rotation=cross_layer_rotation,
            vocab_size=vocab_size,
        )

    @property
    def config(self):
        return self._config.copy()

    def param_report(self):
        total = 0
        name = getattr(self, '_tower_name', getattr(self, 'name', self.__class__.__name__))
        print(f"\n {name} — parameter report")
        print(f" {'Component':<35s} {'Params':>12s}")
        print(f" {'─'*35} {'─'*12}")
        # Iterate attached components directly (one row per component);
        # named_children() would report the whole ModuleDict as a single child.
        items = self._components.items() if hasattr(self, '_components') else self.named_children()
        for cname, module in items:
            n = sum(p.numel() for p in module.parameters())
            total += n
            print(f" {cname:<35s} {n:>12,}")
        print(f" {'─'*35} {'─'*12}")
        print(f" {'TOTAL':<35s} {total:>12,}")
        return total

    def forward(self, x, attn_mask=None, key_padding_mask=None,
                return_geo_state=False):
        """
        Args:
            x: (B, L, D) hidden states or (B, L) token ids
            return_geo_state: if True, also return per-layer geometric state dicts

        Returns:
            out: (B, L, D) transformed hidden states (or logits if head attached)
            geo_states: list of per-layer geo_state dicts (if return_geo_state)
                Each dict contains: embedding, geo_ctx, triangulation,
                cos_to_anchors, assignment, nearest, patchwork, bridge,
                content, geometric, composed
        """
        if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
            pos = torch.arange(x.shape[1], device=x.device)
            x = self['embed'](x) + self['pos_embed'](pos)

        geo_states = []
        has_xrot = self.has('cross_rot_0')

        for i in range(self.n_layers):
            x, geo_state = self[f'layer_{i}'](
                x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            if return_geo_state:
                geo_states.append(geo_state)
            if has_xrot and i < self.n_layers - 1:
                x = self[f'cross_rot_{i}'](x)

        x = self['final_norm'](x)
        if self.has('head'):
            x = self['head'](x)

        return (x, geo_states) if return_geo_state else x

# ═══════════════════════════════════════════════════════════════════════════════
# FACTORIES
# ═══════════════════════════════════════════════════════════════════════════════

def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw):
    """Pre-configured for ESM-2 650M (d=1280)."""
    return GeometricTransformer(name, d_model=1280, n_heads=16,
                                n_layers=n_layers, n_anchors=32, manifold_dim=256,
                                n_comp=8, d_comp=32, context_dim=128, quat_dim=64, **kw)


def geo_transformer_small(name='geo_small', n_layers=4, **kw):
    """Small config for prototyping."""
    return GeometricTransformer(name, d_model=256, n_heads=8,
                                n_layers=n_layers, n_anchors=16, manifold_dim=128,
                                n_comp=4, d_comp=16, context_dim=64, quat_dim=32, **kw)


def geo_transformer_vision(name='geo_vit', n_layers=4, **kw):
    """For the scatter/SVD vision pipeline (patches as tokens)."""
    return GeometricTransformer(name, d_model=384, n_heads=8,
                                n_layers=n_layers, n_anchors=32, manifold_dim=128,
                                n_comp=8, d_comp=16, context_dim=64, quat_dim=32, **kw)

# ═══════════════════════════════════════════════════════════════════════════════
# SELF-TEST
# ═══════════════════════════════════════════════════════════════════════════════

if __name__ == '__main__':
    print("Geometric Transformer — Self-Test")
    print(f" geolip_core available: {_HAS_GEOLIP}")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = geo_transformer_small('test', n_layers=2)
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)
    total = model.param_report()

    B, L, D = 2, 32, 256
    x = torch.randn(B, L, D, device=device)

    out, geos = model(x, return_geo_state=True)
    assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}"
    assert len(geos) == 2

    print(f"\n Input: ({B}, {L}, {D})")
    print(f" Output: {out.shape}")
    print(f" Geo states: {len(geos)} layers")
    print(f" State keys: {sorted(geos[0].keys())}")
    for k, v in geos[0].items():
        if v is not None:
            shape = v.shape if hasattr(v, 'shape') else type(v).__name__
            print(f" {k:<18s}: {shape}")

    # Verify rotations
    for name, module in model.named_modules():
        if isinstance(module, CayleyOrthogonal):
            R = module.get_rotation()
            I = torch.eye(R.shape[0], device=R.device)
            print(f" {name}: ‖RRᵀ-I‖={((R @ R.T) - I).norm():.8f} det={torch.det(R):.4f}")

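    # Token-ID path (sketch): when vocab_size is set, the tower embeds ids
    # itself and the attached head returns logits. The names below are
    # demo-only, not part of the pipeline.
    tok = geo_transformer_small('tok_demo', n_layers=1, vocab_size=100)
    if hasattr(tok, 'network_to'):
        tok.network_to(device=device, strict=False)
    else:
        tok = tok.to(device)
    ids = torch.randint(0, 100, (2, 16), device=device)
    logits = tok(ids)
    assert logits.shape == (2, 16, 100), f"got {logits.shape}"
    print(f"\n Token path: ids (2, 16) → logits {tuple(logits.shape)}")
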
    # ESM-2 scale overhead
    print("\n ESM-2 scale:")
    esm = geo_transformer_esm2('esm2', n_layers=6)
    if hasattr(esm, 'network_to'):
        esm.network_to(device=device, strict=False)
    else:
        esm = esm.to(device)
    n = esm.param_report()
    print(f" Overhead on 650M base: {n/1e6:.1f}M ({n/650e6*100:.1f}%)")

    print(f"\n{'='*60}")
    print(" PASSED")
    print(f"{'='*60}")