Create model.py
model.py
ADDED
"""
Geometric Transformer – GeoLIP Pipeline Integration
=====================================================
Dual-stream transformer with constellation-routed attention,
quaternion composition, and per-layer Cayley alignment.

Uses REAL geolip_core components:
    core.associate.constellation → ConstellationObserver (anchors + triangulation + patchwork)
    core.curate.gate             → AnchorGate (CM determinant validity)
    core.align.procrustes        → CayleyOrthogonal rotation in SO(d)
    pipeline.observer            → TorchComponent / BaseTower interfaces

NEW components (transformer-specific):
    ManifoldProjection        → Input stage: hidden_state → S^(d-1)
    PositionGeometricContext  → Curation: constellation output → FiLM context
    FiLMLayer                 → Feature-wise Linear Modulation (proven in Ryan Spearman)
    GeometricAttention        → Attention with FiLM on Q,K from curated constellation
    QuaternionCompose         → Hamilton product of dual-stream outputs (proven)
    CayleyOrthogonal          → SO(d) rotation via Cayley map (proven)
    DualStreamBlock           → Content + geometric streams, aligned + composed
    GeometricTransformerLayer → Full layer: project → observe → attend → compose
    GeometricTransformer      → Stack of layers with cross-layer rotation

Architecture per layer:
    1. ManifoldProjection: h_i → emb_i on S^(d-1) per position
    2. ConstellationObserver: emb_i → {triangulation, assignment, patchwork, bridge}
    3. PositionGeometricContext: constellation output → (B, L, context_dim)
    4. Stream A (content): standard self-attention
    5. Stream B (geometric): attention with FiLM(Q,K | geo_ctx), V unmodulated
    6. CayleyOrthogonal: align B → A basis
    7. QuaternionCompose: w=content, i=aligned_geo, j=disagree, k=agree
    8. Gated residual

Design principles from Ryan Spearman (ρ=0.309, 76/84 wins):
    - FiLM on Q,K ONLY – geometry routes attention, V stays pure
    - FiLM on individual arms BEFORE composition, not after
    - Quaternion algebra as structural regularizer (non-commutative coupling)
    - Disagreement arm (j) carries the transferable signal
    - CayleyOrthogonal guarantees pure rotation (det=1 always)
    - Never global average pool – per-position geometric context

Usage:
    from geometric_transformer import GeometricTransformer

    model = GeometricTransformer('geo_xfmr', d_model=512, n_layers=4)
    out = model(hidden_states)

    # Or as a head on frozen ESM-2:
    model = GeometricTransformer('esm2_geo', d_model=1280, n_layers=6)
    out = model(esm2_hidden_states)

Dependencies:
    pip install geolip-core (includes constellation, patchwork, gate, observer interfaces)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ──────────────────────────────────────────────────────────────────────────────
# GEOLIP IMPORTS – real components, not reimplementations
# ──────────────────────────────────────────────────────────────────────────────

try:
    from geolip_core.core.associate.constellation import (
        ConstellationObserver, ConstellationAssociation, ConstellationCuration,
        Constellation, init_anchors_repulsion,
    )
    from geolip_core.core.curate.gate import AnchorGate
    from geolip_core.pipeline.observer import (
        TorchComponent, BaseTower, Input, Curation, Distinction,
    )
    _HAS_GEOLIP = True
except ImportError:
    _HAS_GEOLIP = False

    # ── Fallback stubs (only defined when geolip_core is unavailable, so the
    # real components are never shadowed) ──
    class TorchComponent(nn.Module):
        def __init__(self, name=None, **kwargs):
            super().__init__()
            self._component_name = name or self.__class__.__name__

    class BaseTower(nn.Module):
        def __init__(self, name=None, **kwargs):
            super().__init__()
            self._tower_name = name or self.__class__.__name__
            self._components = nn.ModuleDict()
            self._cache = {}

        def attach(self, name, module):
            if isinstance(module, nn.Module):
                self._components[name] = module
            return self

        def has(self, name):
            return name in self._components

        def __getitem__(self, key):
            return self._components[key]

        def cache_set(self, key, value):
            self._cache[key] = value

        def cache_get(self, key, default=None):
            return self._cache.get(key, default)

        def cache_clear(self):
            self._cache.clear()

    Input = TorchComponent
    Curation = TorchComponent
    Distinction = TorchComponent

    class Constellation(nn.Module):
        """Learned anchors on S^(d-1). Triangulates input embeddings."""
        def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
            super().__init__()
            self.n_anchors = n_anchors
            self.dim = dim
            self.anchor_drop = anchor_drop
            anchors = torch.randn(n_anchors, dim)
            # Repulsion-initialized
            anchors = F.normalize(anchors, dim=-1)
            for _ in range(200):
                sim = anchors @ anchors.T
                sim.fill_diagonal_(-2.0)
                anchors = F.normalize(anchors - 0.05 * anchors[sim.argmax(dim=1)], dim=-1)
            self.anchors = nn.Parameter(anchors)

        def triangulate(self, emb, training=False):
            anchors = F.normalize(self.anchors, dim=-1)
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
            return tri, nearest

        def forward(self, emb, training=False):
            return self.triangulate(emb, training)

    class ConstellationAssociation(TorchComponent):
        """Association through constellation anchors."""
        def __init__(self, dim=256, n_anchors=32, anchor_drop=0.0,
                     anchor_init='repulsion', assign_temp=0.1, **kwargs):
            super().__init__(**kwargs)
            self.assign_temp = assign_temp
            self.constellation = Constellation(n_anchors, dim, anchor_drop, anchor_init)

        @property
        def frame_dim(self):
            return self.constellation.n_anchors

        def associate(self, emb, **context):
            anchors_n = F.normalize(self.constellation.anchors, dim=-1)
            cos = emb @ anchors_n.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
            soft_assign = F.softmax(cos / self.assign_temp, dim=-1)
            mag = context.get('mag', None)
            distances_weighted = tri * mag if mag is not None else tri
            return {
                'distances': tri, 'distances_weighted': distances_weighted,
                'cos_to_anchors': cos, 'assignment': soft_assign,
                'nearest': nearest,
            }

        def forward(self, emb, **context):
            return self.associate(emb, **context)

    class Patchwork(nn.Module):
        """Round-robin patchwork compartments."""
        def __init__(self, n_anchors, n_comp=8, d_comp=32, activation='gelu'):
            super().__init__()
            self.n_comp = n_comp
            anchors_per = max(1, n_anchors // n_comp)
            self.compartments = nn.ModuleList([
                nn.Sequential(nn.Linear(anchors_per, d_comp), nn.GELU(), nn.Linear(d_comp, d_comp))
                for _ in range(n_comp)
            ])
            self.output_dim = n_comp * d_comp
            self.anchors_per = anchors_per

        def forward(self, distances):
            parts = []
            for i, comp in enumerate(self.compartments):
                start = i * self.anchors_per
                end = start + self.anchors_per
                chunk = distances[..., start:end]
                if chunk.shape[-1] < self.anchors_per:
                    chunk = F.pad(chunk, (0, self.anchors_per - chunk.shape[-1]))
                parts.append(comp(chunk))
            return torch.cat(parts, dim=-1)

    class ConstellationCuration(Curation):
        """Curation through patchwork compartments + bridge."""
        def __init__(self, n_anchors=32, dim=256, n_comp=8, d_comp=32,
                     activation='gelu', **kwargs):
            super().__init__(**kwargs)
            self.dim = dim
            self.n_anchors = n_anchors
            self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
            pw_dim = self.patchwork.output_dim
            self.bridge = nn.Linear(pw_dim, n_anchors)
            self._feature_dim = n_anchors + pw_dim + dim

        @property
        def feature_dim(self):
            return self._feature_dim

        def curate_full(self, association_output, emb=None, **context):
            distances = association_output['distances_weighted']
            assignment = association_output['assignment']
            pw = self.patchwork(distances)
            bridge = self.bridge(pw)
            parts = [assignment, pw]
            if emb is not None:
                parts.append(emb)
            features = torch.cat(parts, dim=-1)
            return {'patchwork': pw, 'bridge': bridge, 'features': features}

        def forward(self, association_output, emb=None, **context):
            return self.curate_full(association_output, emb=emb, **context)['features']

    class ConstellationObserver(nn.Module):
        """Composed association + curation."""
        def __init__(self, dim=256, n_anchors=32, n_comp=8, d_comp=32,
                     anchor_drop=0.0, anchor_init='repulsion',
                     activation='gelu', assign_temp=0.1):
            super().__init__()
            self.association = ConstellationAssociation(
                dim=dim, n_anchors=n_anchors, anchor_drop=anchor_drop,
                anchor_init=anchor_init, assign_temp=assign_temp)
            self.curation = ConstellationCuration(
                n_anchors=n_anchors, dim=dim, n_comp=n_comp,
                d_comp=d_comp, activation=activation)

        @property
        def constellation(self):
            return self.association.constellation

        @property
        def patchwork(self):
            return self.curation.patchwork

        @property
        def feature_dim(self):
            return self.curation.feature_dim

        def observe(self, emb, **context):
            a_out = self.association(emb, **context)
            c_out = self.curation.curate_full(a_out, emb=emb, **context)
            return {
                'embedding': emb, 'features': c_out['features'],
                'triangulation': a_out['distances'],
                'cos_to_anchors': a_out['cos_to_anchors'],
                'nearest': a_out['nearest'],
                'assignment': a_out['assignment'],
                'patchwork': c_out['patchwork'], 'bridge': c_out['bridge'],
            }

        def forward(self, emb, **context):
            return self.observe(emb, **context)


# ──────────────────────────────────────────────────────────────────────────────
# PROVEN COMPONENTS – from Ryan Spearman (unchanged, tested)
# ──────────────────────────────────────────────────────────────────────────────

class FiLMLayer(TorchComponent):
    """Feature-wise Linear Modulation. Proven in Ryan Spearman.

    Produces γ * x + β from geometric context.
    Identity-initialized: γ=1, β=0 at init.
    """
    def __init__(self, name, feature_dim, context_dim):
        super().__init__(name)
        self.to_gamma = nn.Linear(context_dim, feature_dim)
        self.to_beta = nn.Linear(context_dim, feature_dim)
        nn.init.zeros_(self.to_gamma.weight); nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.weight); nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, ctx):
        """x: (B, L, D), ctx: (B, L, C) → (B, L, D)"""
        return self.to_gamma(ctx) * x + self.to_beta(ctx)

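
# Illustrative sketch (demo-only; the _demo_ name is hypothetical, not part of
# the pipeline): because to_gamma/to_beta start at γ=1, β=0, a freshly
# constructed FiLMLayer is an exact identity regardless of the context.
def _demo_film_identity():
    film = FiLMLayer('demo_film', feature_dim=8, context_dim=4)
    x, ctx = torch.randn(2, 3, 8), torch.randn(2, 3, 4)
    assert torch.allclose(film(x, ctx), x, atol=1e-6)  # γ·x + β == x at init
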

class CayleyOrthogonal(TorchComponent):
    """Guaranteed SO(d) rotation via Cayley map. Proven in Procrustes alignment.

    Q = (I - A)(I + A)^(-1) where A is skew-symmetric.
    det(Q) = 1 always. ‖R-I‖ ≈ 4.1 at convergence in SO(256).

    Caches the rotation matrix – only recomputes when A_upper changes
    (i.e. after optimizer.step()). The solve is input-independent.
    """
    def __init__(self, name, dim):
        super().__init__(name)
        self.dim = dim
        # A = 0 at init → R = I (identity rotation)
        self.A_upper = nn.Parameter(torch.zeros(dim * (dim - 1) // 2))
        self._cached_R = None
        self._cached_A_version = None

    def _param_version(self):
        """Track parameter changes via data_ptr + in-place mutation counter."""
        return self.A_upper.data_ptr(), self.A_upper._version

    def get_rotation(self):
        # During training: always recompute (autograd graph needed fresh)
        # During eval: cache the rotation (params don't change)
        if self.training:
            self._cached_R = None

        version = self._param_version()
        if self._cached_R is not None and self._cached_A_version == version:
            return self._cached_R

        d = self.dim
        A = torch.zeros(d, d, device=self.A_upper.device, dtype=self.A_upper.dtype)
        idx = torch.triu_indices(d, d, offset=1, device=A.device)
        A[idx[0], idx[1]] = self.A_upper
        A = A - A.T
        I = torch.eye(d, device=A.device, dtype=A.dtype)
        R = torch.linalg.solve(I + A, I - A)

        if not self.training:
            self._cached_R = R
            self._cached_A_version = version
        return R

    def invalidate_cache(self):
        """Call after optimizer.step() if needed."""
        self._cached_R = None
        self._cached_A_version = None

    def forward(self, x):
        """(..., dim) → (..., dim) rotated."""
        return x @ self.get_rotation().T

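
# Illustrative sketch (demo-only, hypothetical name): for any skew-symmetric A,
# the Cayley map produces an orthogonal R with det(R) = +1 – a pure rotation,
# never a reflection. We perturb A away from the identity init to show it.
def _demo_cayley_is_rotation():
    rot = CayleyOrthogonal('demo_cayley', dim=16)
    with torch.no_grad():
        rot.A_upper.normal_(0, 0.1)  # nontrivial skew-symmetric generator
    R = rot.get_rotation()
    I = torch.eye(16)
    assert torch.allclose(R @ R.T, I, atol=1e-5)                        # orthogonal
    assert torch.allclose(torch.det(R), torch.tensor(1.0), atol=1e-5)   # det = +1
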

def quaternion_multiply(q1, q2):
    """Hamilton product. q = (w, x, y, z) along dim=-2.

    Supports batched: (..., 4, D) × (..., 4, D) → (..., 4, D)
    Or scalar: (..., 4) × (..., 4) → (..., 4)

    Note: the layout is inferred heuristically – a tensor whose dim=-2 has
    size 4 is treated as (..., 4, D); otherwise the last dim is unbound.
    """
    w1, x1, y1, z1 = q1.unbind(-2) if q1.dim() >= 2 and q1.shape[-2] == 4 else q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-2) if q2.dim() >= 2 and q2.shape[-2] == 4 else q2.unbind(-1)
    stack_dim = -2 if q1.dim() >= 2 and q1.shape[-2] == 4 else -1
    return torch.stack([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ], dim=stack_dim)


def quaternion_multiply_batched(q1, q2):
    """Hamilton product on (B, 4, D) tensors. Fully vectorized, no loops.

    Each of the 4 slices along dim=1 is one quaternion component.
    The D dimension is batched – all D quaternions multiplied in parallel.
    """
    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
    return torch.stack([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ], dim=1)  # (B, 4, D)

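
# Illustrative sketch (demo-only): two algebraic facts the composition below
# relies on – the Hamilton product is non-commutative (so the arms cannot
# decouple) and norm-multiplicative (|pq| = |p||q| per quaternion).
def _demo_hamilton_properties():
    p, q = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
    pq = quaternion_multiply_batched(p, q)
    qp = quaternion_multiply_batched(q, p)
    assert not torch.allclose(pq, qp)  # order matters
    assert torch.allclose(pq.norm(dim=1), p.norm(dim=1) * q.norm(dim=1), atol=1e-4)
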

class QuaternionCompose(TorchComponent):
    """Four-arm Hamilton product composition. Proven in GeoQuat head.

    The algebra forces cross-term interactions between arms.
    Arms cannot independently memorize – the non-commutative
    product couples their outputs as structural regularizer.

    Fully vectorized: single batched Hamilton product, no Python loops.
    """
    def __init__(self, name, input_dim, quat_dim=64):
        super().__init__(name)
        self.quat_dim = quat_dim
        self.proj_w = nn.Linear(input_dim, quat_dim)
        self.proj_i = nn.Linear(input_dim, quat_dim)
        self.proj_j = nn.Linear(input_dim, quat_dim)
        self.proj_k = nn.Linear(input_dim, quat_dim)
        self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1)

    @property
    def output_dim(self):
        return self.quat_dim * 4

    def forward(self, arm_w, arm_i, arm_j, arm_k):
        """Each arm: (B, L, D) → composed: (B, L, 4*quat_dim)"""
        shape = arm_w.shape[:-1]
        D = arm_w.shape[-1]
        flat = arm_w.dim() > 2
        if flat:
            arm_w = arm_w.reshape(-1, D); arm_i = arm_i.reshape(-1, D)
            arm_j = arm_j.reshape(-1, D); arm_k = arm_k.reshape(-1, D)

        # q: (N, 4, quat_dim) – stack 4 projected arms as quaternion components
        q = torch.stack([self.proj_w(arm_w), self.proj_i(arm_i),
                         self.proj_j(arm_j), self.proj_k(arm_k)], dim=1)
        q = q / (q.norm(dim=1, keepdim=True) + 1e-8)

        # r: (N, 4, quat_dim) – broadcast learned rotation
        r = self.rotation.expand(q.shape[0], -1, -1)
        r = r / (r.norm(dim=1, keepdim=True) + 1e-8)

        # Single batched Hamilton product over all quat_dim simultaneously
        # (N, 4, quat_dim) × (N, 4, quat_dim) → (N, 4, quat_dim)
        composed = quaternion_multiply_batched(r, q)

        # Flatten 4 × quat_dim → 4*quat_dim
        composed = composed.reshape(q.shape[0], -1)

        if flat:
            composed = composed.reshape(*shape, -1)
        return composed

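
# Illustrative sketch (demo-only): the composer's shape contract – four
# (B, L, D) arms in, one (B, L, 4*quat_dim) composition out.
def _demo_quaternion_compose_shapes():
    comp = QuaternionCompose('demo_quat', input_dim=32, quat_dim=8)
    a = torch.randn(2, 5, 32)
    assert comp(a, a, a, a).shape == (2, 5, 32)  # 4 * quat_dim = 32
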

# ──────────────────────────────────────────────────────────────────────────────
# NEW COMPONENTS – transformer-specific, built for this architecture
# ──────────────────────────────────────────────────────────────────────────────

class ManifoldProjection(TorchComponent):
    """Input stage: project transformer hidden states to S^(d-1).

    Per-position, per-layer projection from model space to the
    constellation's embedding space. L2-normalized to sit on the
    unit hypersphere.

    This is the tap – it reads the representation without modifying it.
    """
    def __init__(self, name, d_model, manifold_dim):
        super().__init__(name)
        self.proj = nn.Linear(d_model, manifold_dim)
        self.norm = nn.LayerNorm(manifold_dim)

    def forward(self, hidden_states):
        """(B, L, D) → (B, L, manifold_dim) on S^(manifold_dim - 1)"""
        h = self.norm(self.proj(hidden_states))
        return F.normalize(h, dim=-1)

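
# Illustrative sketch (demo-only): every projected position has unit L2 norm,
# i.e. it lies on the hypersphere the constellation anchors expect.
def _demo_manifold_projection():
    proj = ManifoldProjection('demo_proj', d_model=16, manifold_dim=8)
    emb = proj(torch.randn(2, 4, 16))
    assert torch.allclose(emb.norm(dim=-1), torch.ones(2, 4), atol=1e-5)
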

class PositionGeometricContext(TorchComponent):
    """Curation stage: constellation observation → FiLM context vector.

    Takes the full observation dict from ConstellationObserver and fuses
    it into a per-position conditioning vector for FiLM layers.

    Processes: cos_to_anchors, assignment, patchwork, embedding.
    These are the same features the GeoQuat head used – validated on
    ProteinGym across 84 unseen proteins.
    """
    def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim):
        super().__init__(name)
        # Anchor features: cos + assignment + triangulation = 3 * n_anchors
        self.anchor_mlp = nn.Sequential(
            nn.Linear(n_anchors * 3, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )
        # Structural features: patchwork + embedding
        self.struct_mlp = nn.Sequential(
            nn.Linear(pw_dim + manifold_dim, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )
        # Fuse anchor + structural
        self.fuse = nn.Sequential(
            nn.Linear(context_dim * 2, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )

    def forward(self, obs_dict):
        """
        Args:
            obs_dict: from ConstellationObserver.observe(), keys:
                cos_to_anchors: (B*L, A)
                assignment:     (B*L, A)
                triangulation:  (B*L, A)
                patchwork:      (B*L, pw_dim)
                embedding:      (B*L, manifold_dim)
        Returns:
            (B*L, context_dim) geometric context
        """
        anchor_feats = torch.cat([
            obs_dict['cos_to_anchors'],
            obs_dict['assignment'],
            obs_dict['triangulation'],
        ], dim=-1)

        struct_feats = torch.cat([
            obs_dict['patchwork'],
            obs_dict['embedding'],
        ], dim=-1)

        a = self.anchor_mlp(anchor_feats)
        s = self.struct_mlp(struct_feats)
        return self.fuse(torch.cat([a, s], dim=-1))


class GeometricAttention(TorchComponent):
    """Attention with FiLM from curated constellation. Stream B.

    FiLM modulates Q and K BEFORE attention – the constellation
    position controls WHERE attention flows. V stays unmodulated.
    FiLM between FFN layers conditions the nonlinearity.

    Proven principle: context before composition, not after.
    """
    def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1):
        super().__init__(name)
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # FiLM on Q and K – geometry routes attention
        self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim)
        self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim)

        self.norm = nn.LayerNorm(d_model)

        # FFN with FiLM between layers
        self.ffn1 = nn.Linear(d_model, d_model * 4)
        self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim)
        self.ffn2 = nn.Linear(d_model * 4, d_model)
        self.ffn_drop = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None):
        """
        x: (B, L, D), geo_ctx: (B, L, C) → (B, L, D)
        """
        B, L, D = x.shape
        H, HD = self.n_heads, self.head_dim

        Q = self.film_q(self.w_q(x), geo_ctx)
        K = self.film_k(self.w_k(x), geo_ctx)
        V = self.w_v(x)  # V unmodulated – content stays pure

        Q = Q.view(B, L, H, HD).transpose(1, 2)
        K = K.view(B, L, H, HD).transpose(1, 2)
        V = V.view(B, L, H, HD).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            scores = scores + attn_mask
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn_out = (self.dropout(F.softmax(scores, dim=-1)) @ V)
        attn_out = attn_out.transpose(1, 2).reshape(B, L, D)

        x = self.norm(x + self.w_o(attn_out))

        # FFN with geometric FiLM between layers
        h = F.gelu(self.ffn1(x))
        h = self.film_ffn(h, geo_ctx)
        x = self.ffn_norm(x + self.ffn_drop(self.ffn2(h)))

        return x

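
# Illustrative sketch (demo-only): the geometric stream keeps the (B, L, D)
# contract of standard self-attention; only the context input is new.
def _demo_geometric_attention():
    attn = GeometricAttention('demo_geo_attn', d_model=32, n_heads=4, context_dim=8)
    x, ctx = torch.randn(2, 6, 32), torch.randn(2, 6, 8)
    assert attn(x, ctx).shape == (2, 6, 32)
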

class ContentAttention(TorchComponent):
    """Standard self-attention. Stream A. No geometric conditioning."""
    def __init__(self, name, d_model, n_heads=8, dropout=0.1):
        super().__init__(name)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(),
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        a, _ = self.attn(x, x, x, attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask)
        x = self.norm(x + a)
        x = self.ffn_norm(x + self.ffn(x))
        return x


# ──────────────────────────────────────────────────────────────────────────────
# LAYER – dual-stream with constellation routing
# ──────────────────────────────────────────────────────────────────────────────

class GeometricTransformerLayer(BaseTower):
    """One layer of the geometric transformer.

    Pipeline per layer:
      1. ManifoldProjection: h_i → emb_i on S^(manifold_dim - 1)
      2. ConstellationObserver: emb_i → {triangulation, assignment, patchwork, ...}
      3. PositionGeometricContext: observation → FiLM context (B, L, context_dim)
      4. ContentAttention (Stream A): standard MHA
      5. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure
      6. CayleyOrthogonal: align B basis → A basis
      7. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
      8. Decode + gated residual

    Access:
        layer['projection'] → ManifoldProjection
        layer['observer']   → ConstellationObserver
        layer['context']    → PositionGeometricContext
        layer['content']    → ContentAttention
        layer['geometric']  → GeometricAttention
        layer['rotation']   → CayleyOrthogonal
        layer['compose']    → QuaternionCompose
    """
    def __init__(self, name, d_model, n_heads=8, n_anchors=32,
                 manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1):
        super().__init__(name)
        self.d_model = d_model

        # 1. Project to manifold
        self.attach('projection', ManifoldProjection(
            f'{name}_proj', d_model, manifold_dim))

        # 2. Constellation observer (real association + curation)
        self.attach('observer', ConstellationObserver(
            dim=manifold_dim, n_anchors=n_anchors,
            n_comp=n_comp, d_comp=d_comp))

        # 3. Fuse observation into FiLM context
        pw_dim = self['observer'].curation.patchwork.output_dim
        self.attach('context', PositionGeometricContext(
            f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim))

        # 4. Stream A: content
        self.attach('content', ContentAttention(
            f'{name}_content', d_model, n_heads, dropout))

        # 5. Stream B: geometric
        self.attach('geometric', GeometricAttention(
            f'{name}_geo', d_model, n_heads, context_dim, dropout))

        # 6. Cayley rotation: align B → A
        self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model))

        # 7. Quaternion composition
        self.attach('compose', QuaternionCompose(
            f'{name}_quat', d_model, quat_dim))

        # 8. Decode + gate
        self.attach('decode', nn.Sequential(
            nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model)))
        self.attach('gate', nn.Sequential(
            nn.Linear(d_model * 2, d_model), nn.Sigmoid()))

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """
        Args:
            x: (B, L, D) input hidden states

        Returns:
            x_out: (B, L, D) transformed hidden states
            geo_state: dict with full geometric residual:
                'embedding':      (B, L, manifold_dim) position on S^(d-1)
                'geo_ctx':        (B, L, context_dim) compressed FiLM context
                'triangulation':  (B, L, A) cosine distances (1 - cos) to anchors
                'cos_to_anchors': (B, L, A) raw cosine similarities
                'assignment':     (B, L, A) soft assignment
                'nearest':        (B, L) nearest anchor index
                'patchwork':      (B, L, pw_dim) compartment features
                'bridge':         (B, L, A) patchwork's assignment estimate
                'content':        (B, L, D) Stream A output
                'geometric':      (B, L, D) Stream B output (pre-rotation)
                'composed':       (B, L, 4*quat_dim) raw quaternion composition
        """
        B, L, D = x.shape

        # 1. Project to manifold: per-position embedding on S^(d-1)
        emb = self['projection'](x)  # (B, L, manifold_dim)

        # 2. Constellation observation: flatten to (B*L, manifold_dim) for observer
        emb_flat = emb.reshape(B * L, -1)
        obs = self['observer'].observe(emb_flat)

        # 3. Build FiLM context
        geo_ctx_flat = self['context'](obs)       # (B*L, context_dim)
        geo_ctx = geo_ctx_flat.reshape(B, L, -1)  # (B, L, context_dim)

        # 4. Stream A: content attention
        a_out = self['content'](x, attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask)

        # 5. Stream B: geometric attention
        b_out = self['geometric'](x, geo_ctx, attn_mask=attn_mask,
                                  key_padding_mask=key_padding_mask)

        # 6. Cayley rotation: align B → A
        b_aligned = self['rotation'](b_out)

        # 7. Quaternion composition
        #    w = content          (what does standard attention think?)
        #    i = aligned geometry (what does geometric attention think?)
        #    j = disagreement     (where do they diverge? – the surprise signal)
        #    k = agreement        (where do they converge? – the confidence signal)
        composed = self['compose'](
            arm_w=a_out, arm_i=b_aligned,
            arm_j=a_out - b_aligned, arm_k=a_out * b_aligned)

        # 8. Decode + gated residual
        decoded = self['decode'](composed)
        g = self['gate'](torch.cat([x, decoded], dim=-1))
        x_out = g * decoded + (1 - g) * x

        # 9. Build full geometric state – reshape everything back to (B, L, ...)
        def unflatten(t):
            if t is None: return None
            if t.dim() == 1: return t.reshape(B, L)  # (B*L,) → (B, L)
            return t.reshape(B, L, *t.shape[1:])     # (B*L, ...) → (B, L, ...)

        geo_state = {
            'embedding': emb,      # already (B, L, manifold_dim)
            'geo_ctx': geo_ctx,    # already (B, L, context_dim)
            'triangulation': unflatten(obs['triangulation']),
            'cos_to_anchors': unflatten(obs['cos_to_anchors']),
            'assignment': unflatten(obs['assignment']),
            'nearest': unflatten(obs['nearest']),
            'patchwork': unflatten(obs['patchwork']),
            'bridge': unflatten(obs['bridge']),
            'content': a_out,      # (B, L, D)
            'geometric': b_out,    # (B, L, D) pre-rotation
            'composed': composed,  # (B, L, 4*quat_dim)
        }

        return x_out, geo_state

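
# Illustrative sketch (demo-only, small hypothetical dims; assumes the
# ConstellationObserver interface used above): one layer preserves the
# hidden-state shape and exposes per-position geometry alongside it.
def _demo_layer_roundtrip():
    layer = GeometricTransformerLayer('demo_layer', d_model=64, n_heads=4,
                                      n_anchors=8, manifold_dim=32, n_comp=4,
                                      d_comp=8, context_dim=16, quat_dim=16)
    x = torch.randn(2, 5, 64)
    x_out, geo = layer(x)
    assert x_out.shape == x.shape
    assert geo['assignment'].shape == (2, 5, 8)  # soft anchor assignment per token
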

# ──────────────────────────────────────────────────────────────────────────────
# FULL MODEL – stack of layers
# ──────────────────────────────────────────────────────────────────────────────

class GeometricTransformer(BaseTower):
    """Geometric Transformer – dual-stream with constellation routing.

    Stack of GeometricTransformerLayers. Optional cross-layer Cayley
    rotation aligns each layer's output basis to the next layer's
    expected input.

    Access:
        model['layer_0']     → first layer
        model['cross_rot_0'] → cross-layer rotation 0→1
        model['final_norm']  → output normalization

    Args:
        name: tower identity
        d_model: transformer model dimension
        n_heads: attention heads per stream
        n_layers: number of geometric transformer layers
        n_anchors: constellation anchor points
        manifold_dim: dimension of S^(d-1) for constellation
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        context_dim: FiLM conditioning dimension
        quat_dim: quaternion space dimension
        dropout: dropout rate
        cross_layer_rotation: add Cayley rotation between layers
        vocab_size: if set, adds embedding + output head
    """
    def __init__(self, name, d_model=512, n_heads=8, n_layers=4,
                 n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1,
                 cross_layer_rotation=True, vocab_size=None, max_seq_len=2048):
        super().__init__(name)
        self.d_model = d_model
        self.n_layers = n_layers

        if vocab_size is not None:
            self.attach('embed', nn.Embedding(vocab_size, d_model))
            self.attach('pos_embed', nn.Embedding(max_seq_len, d_model))
            self.attach('head', nn.Linear(d_model, vocab_size, bias=False))

        for i in range(n_layers):
            self.attach(f'layer_{i}', GeometricTransformerLayer(
                f'{name}_L{i}', d_model, n_heads, n_anchors,
                manifold_dim, n_comp, d_comp, context_dim, quat_dim, dropout))

        if cross_layer_rotation and n_layers > 1:
            for i in range(n_layers - 1):
                self.attach(f'cross_rot_{i}', CayleyOrthogonal(
                    f'{name}_xrot_{i}', d_model))

        self.attach('final_norm', nn.LayerNorm(d_model))

        self._config = dict(
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            n_anchors=n_anchors, manifold_dim=manifold_dim,
            n_comp=n_comp, d_comp=d_comp, context_dim=context_dim,
            quat_dim=quat_dim, dropout=dropout,
            cross_layer_rotation=cross_layer_rotation,
            vocab_size=vocab_size,
        )

    @property
    def config(self):
        return self._config.copy()

    def param_report(self):
        total = 0
        name = getattr(self, '_tower_name', getattr(self, 'name', self.__class__.__name__))
        print(f"\n {name} – parameter report")
        print(f" {'Component':<35s} {'Params':>12s}")
        print(f" {'─'*35} {'─'*12}")
        for cname, module in self.named_children():
            n = sum(p.numel() for p in module.parameters())
            total += n
            print(f" {cname:<35s} {n:>12,}")
        print(f" {'─'*35} {'─'*12}")
        print(f" {'TOTAL':<35s} {total:>12,}")
        return total

    def forward(self, x, attn_mask=None, key_padding_mask=None,
                return_geo_state=False):
        """
        Args:
            x: (B, L, D) hidden states or (B, L) token ids
            return_geo_state: if True, return per-layer geometric state dicts

        Returns:
            out: (B, L, D) transformed hidden states (or logits if head attached)
            geo_states: list of per-layer geo_state dicts (if return_geo_state)
                Each dict contains: embedding, geo_ctx, triangulation,
                cos_to_anchors, assignment, nearest, patchwork, bridge,
                content, geometric, composed
        """
        if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
            pos = torch.arange(x.shape[1], device=x.device)
            x = self['embed'](x) + self['pos_embed'](pos)

        geo_states = []
        has_xrot = self.has('cross_rot_0')

        for i in range(self.n_layers):
            x, geo_state = self[f'layer_{i}'](
                x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            if return_geo_state:
                geo_states.append(geo_state)
            if has_xrot and i < self.n_layers - 1:
                x = self[f'cross_rot_{i}'](x)

        x = self['final_norm'](x)
        if self.has('head'):
            x = self['head'](x)

        return (x, geo_states) if return_geo_state else x

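
# Illustrative sketch (demo-only, small hypothetical dims): with vocab_size
# set, the tower embeds token ids and returns vocabulary logits.
def _demo_token_mode():
    model = GeometricTransformer('demo_tok', d_model=64, n_heads=4, n_layers=2,
                                 n_anchors=8, manifold_dim=32, n_comp=4, d_comp=8,
                                 context_dim=16, quat_dim=16, vocab_size=100)
    tokens = torch.randint(0, 100, (2, 10))
    assert model(tokens).shape == (2, 10, 100)
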

# ──────────────────────────────────────────────────────────────────────────────
# FACTORIES
# ──────────────────────────────────────────────────────────────────────────────

def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw):
    """Pre-configured for ESM-2 650M (d=1280)."""
    return GeometricTransformer(name, d_model=1280, n_heads=16,
                                n_layers=n_layers, n_anchors=32, manifold_dim=256,
                                n_comp=8, d_comp=32, context_dim=128, quat_dim=64, **kw)


def geo_transformer_small(name='geo_small', n_layers=4, **kw):
    """Small config for prototyping."""
    return GeometricTransformer(name, d_model=256, n_heads=8,
                                n_layers=n_layers, n_anchors=16, manifold_dim=128,
                                n_comp=4, d_comp=16, context_dim=64, quat_dim=32, **kw)


def geo_transformer_vision(name='geo_vit', n_layers=4, **kw):
    """For scatter/SVD vision pipeline (patches as tokens)."""
    return GeometricTransformer(name, d_model=384, n_heads=8,
                                n_layers=n_layers, n_anchors=32, manifold_dim=128,
                                n_comp=8, d_comp=16, context_dim=64, quat_dim=32, **kw)

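
# Illustrative sketch (demo-only): the ESM-2 factory used as a head on
# precomputed hidden states. The random tensor stands in for real ESM-2 650M
# activations; loading ESM-2 itself is out of scope for this file.
def _demo_esm2_head():
    model = geo_transformer_esm2('demo_esm2', n_layers=2)
    esm2_hidden = torch.randn(1, 64, 1280)  # (B, L, 1280) stand-in
    assert model(esm2_hidden).shape == (1, 64, 1280)
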

# ──────────────────────────────────────────────────────────────────────────────
# SELF-TEST
# ──────────────────────────────────────────────────────────────────────────────

if __name__ == '__main__':
    print("Geometric Transformer – Self-Test")
    print(f" geolip_core available: {_HAS_GEOLIP}")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = geo_transformer_small('test', n_layers=2)
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)
    total = model.param_report()

    B, L, D = 2, 32, 256
    x = torch.randn(B, L, D, device=device)

    out, geos = model(x, return_geo_state=True)
    assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}"
    assert len(geos) == 2

    print(f"\n Input: ({B}, {L}, {D})")
    print(f" Output: {out.shape}")
    print(f" Geo states: {len(geos)} layers")
    print(f" State keys: {sorted(geos[0].keys())}")
    for k, v in geos[0].items():
        if v is not None:
            shape = v.shape if hasattr(v, 'shape') else type(v).__name__
            print(f" {k:<18s}: {shape}")

    # Verify rotations
    for name, module in model.named_modules():
        if isinstance(module, CayleyOrthogonal):
            R = module.get_rotation()
            I = torch.eye(R.shape[0], device=R.device)
            print(f" {name}: ‖RRᵀ-I‖={((R @ R.T) - I).norm():.8f} det={torch.det(R):.4f}")

    # ESM-2 scale overhead
    print("\n ESM-2 scale:")
    esm = geo_transformer_esm2('esm2', n_layers=6)
    if hasattr(esm, 'network_to'):
        esm.network_to(device=device, strict=False)
    else:
        esm = esm.to(device)
    n = esm.param_report()
    print(f" Overhead on 650M base: {n/1e6:.1f}M ({n/650e6*100:.1f}%)")

    print(f"\n{'='*60}")
    print(" PASSED")
    print(f"{'='*60}")