| """ |
| Geometric Transformer β CM-Validated Pipeline |
| ================================================== |
| Dual-stream transformer with CM-gated constellation observation, |
| quaternion composition, and per-layer Cayley alignment. |
| |
| CM-validated pipeline changes: |
| - CM validity gate between association and curation (AnchorGate) |
| - 4-stream PositionGeometricContext: anchor + structural + history + quality |
| - CM-conditioned geometric residual accumulation (replaces blind learned gate) |
| - Built-in geometric regularization (CV target + anchor spread) |
| - Decomposed observer pipeline: association β CM gate β gated curation |
| |
| Pipeline per layer: |
| 1. ManifoldProjection: h_i β emb_i on S^(d-1) per position |
| 2. ConstellationAssociation: emb_i β raw triangulation, cos, assignment |
| 3. CMValidatedGate: per-anchor CM validity β gate_values (B*L, A) |
| 4. Gated curation: patchwork reads tri * gate_values (validated only) |
| 5. PositionGeometricContext: 4 streams β FiLM context (B, L, context_dim) |
| 6. ContentAttention (Stream A): standard MHA |
| 7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure |
| 8. CayleyOrthogonal: align B β A basis |
| 9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B |
| 10. Decode + gated residual |
| 11. CM-conditioned geometric residual write |
| |
| Geometric regularization (call model.geometric_losses() during training): |
| - CV loss: anchor CV β pentachoron band (0.20-0.23) |
| - Spread loss: prevent anchor collapse (penalize positive cosine) |
| These maintain the constellation in the regime where CM validation works. |
| |
| Design principles from Ryan Spearman (Ο=0.309, 76/84 wins): |
| - FiLM on Q,K ONLY β geometry routes attention, V stays pure |
| - FiLM on individual arms BEFORE composition, not after |
| - Quaternion algebra as structural regularizer (non-commutative coupling) |
| - CayleyOrthogonal guarantees pure rotation (det=1 always) |
| - Never global average pool β per-position geometric context |
| |
| Author: AbstractPhil + Claude Opus 4.6 |
| License: Apache 2.0 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| |
|
|
| try: |
| from geolip_core.core.associate.constellation import ( |
| ConstellationObserver, ConstellationAssociation, ConstellationCuration, |
| Constellation, init_anchors_repulsion, |
| ) |
| from geolip_core.core.curate.gate import AnchorGate as _GeolipAnchorGate |
| from geolip_core.pipeline.observer import ( |
| TorchComponent, BaseTower, Input, Curation, Distinction, |
| ) |
| from geolip_core.core.distinguish.losses import ( |
| observer_loss as _geolip_observer_loss, |
| ce_loss_paired as _geolip_ce_loss_paired, |
| cv_loss as _geolip_cv_loss, |
| spread_loss as _geolip_spread_loss, |
| ) |
| _HAS_GEOLIP = True |
| except ImportError: |
| _HAS_GEOLIP = False |
|
|
| |
class TorchComponent(nn.Module):
    """Minimal named-component stand-in for geolip's TorchComponent.

    NOTE(review): defined unconditionally, so it shadows the geolip import
    above even when _HAS_GEOLIP is True — confirm this is intended.
    """
    def __init__(self, name=None, **kwargs):
        super().__init__()
        # Fall back to the concrete class name when no (truthy) name is given.
        self._component_name = name or type(self).__name__
|
|
class BaseTower(nn.Module):
    """Named container of sub-components plus a transient value cache.

    Components live in an ``nn.ModuleDict`` so they are registered for
    parameters and ``state_dict``; the cache is a plain dict for transient
    per-forward values that should NOT be saved or traversed.
    """
    def __init__(self, name=None, **kwargs):
        super().__init__()
        self._tower_name = name or self.__class__.__name__
        self._components = nn.ModuleDict()
        self._cache = {}

    def attach(self, name, module):
        """Register *module* under *name*; silently skip non-Modules.

        Fix: always return ``self`` so calls can be chained — the original
        returned ``None`` when *module* was not an ``nn.Module``, which
        broke fluent use while still silently skipping registration.
        """
        if isinstance(module, nn.Module):
            self._components[name] = module
        return self

    def has(self, name):
        """True if a component is attached under *name*."""
        return name in self._components

    def __getitem__(self, key):
        """Dictionary-style access to attached components."""
        return self._components[key]

    def cache_set(self, key, value):
        """Store a transient value (not part of state_dict)."""
        self._cache[key] = value

    def cache_get(self, key, default=None):
        """Fetch a transient value, or *default* when absent."""
        return self._cache.get(key, default)

    def cache_clear(self):
        """Drop all transient values."""
        self._cache.clear()
|
|
# Role aliases for the fallback pipeline: geolip distinguishes these
# component roles, but the local stand-ins share a single implementation.
Input = TorchComponent
Curation = TorchComponent
Distinction = TorchComponent
|
|
class Constellation(nn.Module):
    """Learned anchor set on the unit hypersphere S^(d-1).

    Anchors start from a short repulsion relaxation so they spread out
    before training. ``anchor_drop`` and ``anchor_init`` are accepted for
    interface compatibility but are unused in this fallback implementation.
    """
    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        self.n_anchors = n_anchors
        self.dim = dim
        pts = F.normalize(torch.randn(n_anchors, dim), dim=-1)
        # Repulsion relaxation: push each anchor away from its current
        # nearest neighbour, then renormalize back onto the sphere.
        for _ in range(200):
            sim = pts @ pts.T
            sim.fill_diagonal_(-2.0)  # exclude self from nearest-neighbour search
            neighbour = sim.argmax(dim=1)
            pts = F.normalize(pts - 0.05 * pts[neighbour], dim=-1)
        self.anchors = nn.Parameter(pts)

    def forward(self, emb, training=False):
        """Triangulate *emb* (N, D) against the anchors.

        Returns (tri, nearest): distances ``1 - cos`` of shape (N, A) and
        the index of the closest anchor per row. ``training`` is unused.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        cos = emb @ unit_anchors.T
        return 1.0 - cos, cos.argmax(dim=-1)
|
|
class ConstellationAssociation(TorchComponent):
    """Associate embeddings with constellation anchors.

    Emits triangulation distances (1 - cos), raw cosines, a
    temperature-softened soft assignment, and the hard nearest anchor.
    """
    def __init__(self, dim=256, n_anchors=32, anchor_drop=0.0,
                 anchor_init='repulsion', assign_temp=0.1, **kwargs):
        super().__init__(**kwargs)
        self.assign_temp = assign_temp
        self.constellation = Constellation(n_anchors, dim, anchor_drop, anchor_init)

    @property
    def frame_dim(self):
        # One coordinate per anchor in the triangulation frame.
        return self.constellation.n_anchors

    def associate(self, emb, **context):
        """Triangulate *emb* (N, D) against the normalized anchors.

        ``context['mag']``, when provided, rescales the distances to form
        'distances_weighted'; otherwise the two entries are identical.
        """
        unit_anchors = F.normalize(self.constellation.anchors, dim=-1)
        cos = emb @ unit_anchors.T
        tri = 1.0 - cos
        mag = context.get('mag', None)
        return {
            'distances': tri,
            'distances_weighted': tri if mag is None else tri * mag,
            'cos_to_anchors': cos,
            'assignment': F.softmax(cos / self.assign_temp, dim=-1),
            'nearest': cos.argmax(dim=-1),
        }

    def forward(self, emb, **context):
        return self.associate(emb, **context)
|
|
class Patchwork(nn.Module):
    """Split anchor distances across round-robin compartment MLPs.

    Each of the ``n_comp`` compartments reads a contiguous slice of
    ``anchors_per = max(1, n_anchors // n_comp)`` distances. Slices that
    run off the end are zero-padded; anchors beyond
    ``n_comp * anchors_per`` are never read. ``activation`` is accepted
    for interface compatibility but GELU is always used.
    """
    def __init__(self, n_anchors, n_comp=8, d_comp=32, activation='gelu'):
        super().__init__()
        self.n_comp = n_comp
        per = max(1, n_anchors // n_comp)
        self.compartments = nn.ModuleList(
            nn.Sequential(nn.Linear(per, d_comp), nn.GELU(), nn.Linear(d_comp, d_comp))
            for _ in range(n_comp)
        )
        self.output_dim = n_comp * d_comp
        self.anchors_per = per

    def forward(self, distances):
        """(..., n_anchors) -> (..., n_comp * d_comp) concatenated features."""
        per = self.anchors_per
        outputs = []
        for idx, comp in enumerate(self.compartments):
            piece = distances[..., idx * per:(idx + 1) * per]
            deficit = per - piece.shape[-1]
            if deficit > 0:
                piece = F.pad(piece, (0, deficit))
            outputs.append(comp(piece))
        return torch.cat(outputs, dim=-1)
|
|
class ConstellationCuration(Curation):
    """Interpret triangulation through the patchwork plus a bridge head."""
    def __init__(self, n_anchors=32, dim=256, n_comp=8, d_comp=32,
                 activation='gelu', **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.n_anchors = n_anchors
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
        pw_dim = self.patchwork.output_dim
        # Bridge maps patchwork features back into per-anchor space.
        self.bridge = nn.Linear(pw_dim, n_anchors)
        self._feature_dim = n_anchors + pw_dim + dim

    @property
    def feature_dim(self):
        # Width of [assignment | patchwork | embedding] when emb is supplied.
        return self._feature_dim

    def curate_full(self, association_output, emb=None, **context):
        """Full curation dict: patchwork, bridge, concatenated features.

        Features are [assignment, patchwork] and, when *emb* is given,
        the raw embedding appended (only then is the width feature_dim).
        """
        pw = self.patchwork(association_output['distances_weighted'])
        pieces = [association_output['assignment'], pw]
        if emb is not None:
            pieces.append(emb)
        return {
            'patchwork': pw,
            'bridge': self.bridge(pw),
            'features': torch.cat(pieces, dim=-1),
        }

    def forward(self, association_output, emb=None, **context):
        return self.curate_full(association_output, emb=emb, **context)['features']
|
|
class ConstellationObserver(nn.Module):
    """Association followed by curation, exposed as a single observer."""
    def __init__(self, dim=256, n_anchors=32, n_comp=8, d_comp=32,
                 anchor_drop=0.0, anchor_init='repulsion',
                 activation='gelu', assign_temp=0.1):
        super().__init__()
        self.association = ConstellationAssociation(
            dim=dim, n_anchors=n_anchors, anchor_drop=anchor_drop,
            anchor_init=anchor_init, assign_temp=assign_temp)
        self.curation = ConstellationCuration(
            n_anchors=n_anchors, dim=dim, n_comp=n_comp,
            d_comp=d_comp, activation=activation)

    @property
    def constellation(self):
        # Convenience passthrough to the anchor set.
        return self.association.constellation

    @property
    def patchwork(self):
        return self.curation.patchwork

    @property
    def feature_dim(self):
        return self.curation.feature_dim

    def observe(self, emb, **context):
        """Run association then curation; merge both outputs into one dict."""
        assoc = self.association(emb, **context)
        cur = self.curation.curate_full(assoc, emb=emb, **context)
        return {
            'embedding': emb,
            'features': cur['features'],
            'triangulation': assoc['distances'],
            'cos_to_anchors': assoc['cos_to_anchors'],
            'nearest': assoc['nearest'],
            'assignment': assoc['assignment'],
            'patchwork': cur['patchwork'],
            'bridge': cur['bridge'],
        }

    def forward(self, emb, **context):
        return self.observe(emb, **context)
|
|
|
|
| |
| |
| |
|
|
def pairwise_distances_squared(points):
    """Squared Euclidean distances between all row pairs.

    Uses the Gram-matrix identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b.

    Args:
        points: (B, N, D) batch of point sets.

    Returns:
        (B, N, N) squared distances.
    """
    gram = points @ points.transpose(1, 2)
    sq_norms = gram.diagonal(dim1=-2, dim2=-1)
    return sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram
|
|
|
|
def cayley_menger_det(points):
    """Sign-corrected Cayley-Menger determinant for a batch of simplices.

    Args:
        points: (B, K, D) vertex coordinates; K = k+1 vertices of a k-simplex.

    Returns:
        (B,) determinants, sign-flipped by (-1)^(k+1) so that a valid
        non-degenerate simplex yields a positive value (proportional to
        the squared simplex volume).
    """
    B, K, _ = points.shape
    sq = pairwise_distances_squared(points)
    # Border the squared-distance matrix with ones (zero corner) to form
    # the (K+1) x (K+1) Cayley-Menger matrix.
    cm = torch.zeros(B, K + 1, K + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = sq
    parity = (-1.0) ** K  # K = k + 1, so this equals (-1)^(k+1)
    return parity * torch.linalg.det(cm)
|
|
|
|
def anchor_neighborhood_cm(anchors, n_neighbors=3):
    """Per-anchor Cayley-Menger quality from local neighborhood geometry.

    Position-independent: each anchor forms a simplex with its
    ``n_neighbors`` nearest anchors, and the CM determinant of that
    simplex measures how well-conditioned the neighborhood is for
    triangulation.

    Args:
        anchors: (A, D) anchor positions (expected normalized by caller).
        n_neighbors: neighbors per simplex.

    Returns:
        quality: (A,) signed log-magnitude CM quality per anchor.
        nn_idx: (A, n_neighbors) neighbor indices.
    """
    A, D = anchors.shape
    pair_d = torch.cdist(anchors.unsqueeze(0), anchors.unsqueeze(0)).squeeze(0)

    # Push the diagonal far out of range so an anchor never picks itself.
    pair_d = pair_d + torch.eye(A, device=anchors.device, dtype=anchors.dtype) * 1e12
    _, nn_idx = pair_d.topk(n_neighbors, largest=False)

    # Simplex per anchor: the anchor itself followed by its neighbors.
    # Vectorized gather replaces the original per-neighbor loop.
    simplices = torch.cat([anchors.unsqueeze(1), anchors[nn_idx]], dim=1)

    dets = cayley_menger_det(simplices)
    # Compress the determinants' huge dynamic range into signed log-magnitude.
    return dets.sign() * torch.log(dets.abs() + 1e-12), nn_idx
|
|
|
|
| |
| |
| |
|
|
class CMValidatedGate(nn.Module):
    """Per-anchor validity gate driven by Cayley-Menger geometry.

    Anchor-level CM quality is precomputed once per call under no_grad
    (position-independent, O(A) small determinants), then combined with
    two position-dependent proximity features through a tiny MLP.

    The final Linear starts with zero weights and bias +2, so every gate
    opens at sigmoid(2) ≈ 0.88 and training can only learn to CLOSE on
    geometrically invalid configurations — architecture before loss.

    Gate features per (position, anchor):
    - z-scored anchor CM quality (position-independent)
    - cosine to the anchor (position-dependent)
    - normalized proximity rank of the anchor (position-dependent)

    Args:
        n_anchors: number of constellation anchors.
        n_neighbors: neighbors for the CM simplex computation.
    """
    def __init__(self, n_anchors, n_neighbors=3):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_neighbors = n_neighbors

        # Three scalar features per (position, anchor) pair.
        self.gate_proj = nn.Sequential(
            nn.Linear(3, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        )
        # Open-at-init: with zero weights the output equals the +2 bias.
        nn.init.zeros_(self.gate_proj[2].weight)
        nn.init.constant_(self.gate_proj[2].bias, 2.0)

    def forward(self, embedding, anchors, tri):
        """Gate each (position, anchor) pair.

        Args:
            embedding: (N, D) positions on S^(d-1), N = B*L (not read
                directly; kept for interface stability).
            anchors: (A, D) normalized anchors (DETACHED by caller).
            tri: (N, A) triangulation distances (1 - cos).

        Returns:
            gate_values: (N, A) in [0, 1] per-anchor validity gate.
            gate_info: dict of detached diagnostics.
        """
        N, A = tri.shape

        # Anchor-neighborhood CM quality: no gradient into the anchors.
        with torch.no_grad():
            anchor_cm, nn_idx = anchor_neighborhood_cm(anchors, self.n_neighbors)
        # Z-score across anchors so the MLP sees a stable scale.
        anchor_cm_norm = (anchor_cm - anchor_cm.mean()) / anchor_cm.std().clamp(min=1e-8)

        cos_sim = 1.0 - tri

        # Double argsort yields the rank of each anchor by proximity,
        # normalized into [0, 1].
        ranks = tri.argsort(dim=-1).argsort(dim=-1).float() / max(A - 1, 1)

        feats = torch.stack([
            anchor_cm_norm.unsqueeze(0).expand(N, -1),
            cos_sim,
            ranks,
        ], dim=-1)

        gate_values = torch.sigmoid(self.gate_proj(feats).squeeze(-1))

        with torch.no_grad():
            gate_info = {
                'active': (gate_values > 0.5).float().sum(-1).mean(),
                'gate_mean': gate_values.mean(),
                'cm_positive_frac': (anchor_cm > 0).float().mean(),
                'anchor_cm': anchor_cm.detach(),
            }

        return gate_values, gate_info
|
|
|
|
| |
| |
| |
|
|
class GeoResidualBank(nn.Module):
    """Cross-stream contrastive memory bank (MoCo/CLIP-style queue).

    Aligns content (Stream A CLS) and geometry (geo_residual CLS) via
    InfoNCE: the same sample's projected content and geometry are the
    positive pair; queued geometry keys from past batches are negatives.
    Gradients flow through BOTH live projections; the queue is detached.

    Args:
        proj_dim: shared projection width for queries and keys.
        bank_size: number of entries in the ring-buffer queue.
        temperature: InfoNCE temperature.
    """
    def __init__(self, proj_dim, bank_size=4096, temperature=0.1):
        super().__init__()
        self.proj_dim = proj_dim
        self.bank_size = bank_size
        self.temperature = temperature

        # Random normalized keys so the bank is usable before it fills.
        self.register_buffer('queue', torch.randn(bank_size, proj_dim))
        self.queue = F.normalize(self.queue, dim=-1)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def enqueue(self, keys):
        """Push normalized geo keys into the ring buffer (call AFTER backward).

        Args:
            keys: (B, proj_dim) normalized projected geo_residual CLS.
        """
        B = keys.shape[0]
        if B > self.bank_size:
            # Robustness fix: a batch larger than the bank previously made
            # the wraparound slice assignment fail; keep only the newest keys.
            keys = keys[-self.bank_size:]
            B = self.bank_size
        ptr = int(self.queue_ptr.item())
        if ptr + B <= self.bank_size:
            self.queue[ptr:ptr + B] = keys
        else:
            # Wrap around the end of the ring buffer.
            overflow = (ptr + B) - self.bank_size
            self.queue[ptr:] = keys[:B - overflow]
            self.queue[:overflow] = keys[B - overflow:]
        self.queue_ptr[0] = (ptr + B) % self.bank_size

    def forward(self, content_proj, geo_proj):
        """Cross-stream InfoNCE: content queries vs geometry keys.

        Args:
            content_proj: (B, proj_dim) projected content CLS (live, has grad).
            geo_proj: (B, proj_dim) projected geo_residual CLS (live, has grad).

        Returns:
            loss: scalar InfoNCE loss.
            acc: top-1 retrieval accuracy (diagnostic, detached).
        """
        q = F.normalize(content_proj, dim=-1)
        k_pos = F.normalize(geo_proj, dim=-1)
        k_neg = self.queue.clone().detach()

        # Positive: diagonal pairing of the current batch.
        pos_logits = (q * k_pos).sum(dim=-1, keepdim=True) / self.temperature
        # Negatives: the whole queue.
        neg_logits = q @ k_neg.T / self.temperature

        # Column 0 is always the positive.
        logits = torch.cat([pos_logits, neg_logits], dim=1)
        labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
        loss = F.cross_entropy(logits, labels)

        with torch.no_grad():
            acc = (logits.argmax(dim=1) == 0).float().mean()

        return loss, acc
|
|
|
|
| |
| |
| |
|
|
class FiLMLayer(TorchComponent):
    """Feature-wise Linear Modulation: x -> gamma(ctx) * x + beta(ctx).

    Initialized to the identity transform (gamma = 1, beta = 0), so the
    layer is a no-op until trained.
    """
    def __init__(self, name, feature_dim, context_dim):
        super().__init__(name)
        self.to_gamma = nn.Linear(context_dim, feature_dim)
        self.to_beta = nn.Linear(context_dim, feature_dim)
        # Zero weights with unit/zero biases give gamma == 1, beta == 0 at init.
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.weight)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, ctx):
        scale = self.to_gamma(ctx)
        shift = self.to_beta(ctx)
        return scale * x + shift
|
|
|
|
class CayleyOrthogonal(TorchComponent):
    """Learned rotation in SO(d) via the Cayley transform.

    Q = (I + A)^(-1)(I - A) with A skew-symmetric is orthogonal with
    det(Q) = +1, so the map is a pure rotation by construction.
    Initialized at A = 0, i.e. Q = I (identity rotation).
    """
    def __init__(self, name, dim):
        super().__init__(name)
        self.dim = dim
        # Free parameters of the strict upper triangle of A.
        # Fix: the original computed ``torch.zeros(...) * 0.01`` — a no-op
        # that suggests a mis-edited small random init. The effective
        # behavior (identity at init) is preserved; the dead multiply is gone.
        self.A_upper = nn.Parameter(torch.zeros(dim * (dim - 1) // 2))
        idx = torch.triu_indices(dim, dim, offset=1)
        self.register_buffer('_triu_row', idx[0], persistent=False)
        self.register_buffer('_triu_col', idx[1], persistent=False)
        self.register_buffer('_eye', torch.eye(dim), persistent=False)

    def get_rotation(self):
        """Materialize the current rotation matrix Q from A_upper."""
        d = self.dim
        A = torch.zeros(d, d, device=self.A_upper.device, dtype=self.A_upper.dtype)
        A[self._triu_row, self._triu_col] = self.A_upper
        A = A - A.T  # skew-symmetrize
        # Cayley map: solve (I + A) Q = (I - A) instead of forming the inverse.
        return torch.linalg.solve(self._eye + A, self._eye - A)

    def forward(self, x):
        """Rotate the last dimension of x (right-multiply by Q^T)."""
        return x @ self.get_rotation().T
|
|
|
|
def quaternion_multiply_batched(q1, q2):
    """Hamilton product of quaternion bundles, fully vectorized.

    Args:
        q1, q2: (B, 4, D) tensors with dim 1 ordered (w, x, y, z); each of
            the D lanes is an independent quaternion.

    Returns:
        (B, 4, D) tensor holding q1 * q2 (non-commutative).
    """
    w1, x1, y1, z1 = q1.unbind(dim=1)
    w2, x2, y2, z2 = q2.unbind(dim=1)
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack((w, x, y, z), dim=1)
|
|
|
|
class QuaternionCompose(TorchComponent):
    """Project four arms to quaternion lanes and Hamilton-compose them.

    Each arm is projected to ``quat_dim`` lanes; the stacked (w, i, j, k)
    bundle is normalized per lane, left-multiplied by a learned rotation
    quaternion, and flattened to ``4 * quat_dim`` features. Fully
    vectorized — one batched Hamilton product, no Python loops.
    """
    def __init__(self, name, input_dim, quat_dim=64):
        super().__init__(name)
        self.quat_dim = quat_dim
        self.proj_w = nn.Linear(input_dim, quat_dim)
        self.proj_i = nn.Linear(input_dim, quat_dim)
        self.proj_j = nn.Linear(input_dim, quat_dim)
        self.proj_k = nn.Linear(input_dim, quat_dim)
        # Learned left-rotation quaternion, one per lane.
        self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1)

    @property
    def output_dim(self):
        return self.quat_dim * 4

    def forward(self, arm_w, arm_i, arm_j, arm_k):
        """Compose arms of shape (..., input_dim) into (..., 4 * quat_dim)."""
        lead = arm_w.shape[:-1]
        feat = arm_w.shape[-1]
        needs_flatten = arm_w.dim() > 2
        if needs_flatten:
            # Collapse leading dims so the Hamilton product is one batched op.
            arm_w = arm_w.reshape(-1, feat)
            arm_i = arm_i.reshape(-1, feat)
            arm_j = arm_j.reshape(-1, feat)
            arm_k = arm_k.reshape(-1, feat)
        bundle = torch.stack(
            [self.proj_w(arm_w), self.proj_i(arm_i),
             self.proj_j(arm_j), self.proj_k(arm_k)], dim=1)
        bundle = bundle / (bundle.norm(dim=1, keepdim=True) + 1e-8)
        rot = self.rotation.expand(bundle.shape[0], -1, -1)
        rot = rot / (rot.norm(dim=1, keepdim=True) + 1e-8)
        out = quaternion_multiply_batched(rot, bundle).reshape(bundle.shape[0], -1)
        if needs_flatten:
            out = out.reshape(*lead, -1)
        return out
|
|
|
|
| |
| |
| |
|
|
class ManifoldProjection(TorchComponent):
    """Project transformer hidden states onto the unit hypersphere S^(d-1).

    Linear -> LayerNorm -> L2 normalize, applied per position per layer.
    """
    def __init__(self, name, d_model, manifold_dim):
        super().__init__(name)
        self.proj = nn.Linear(d_model, manifold_dim)
        self.norm = nn.LayerNorm(manifold_dim)

    def forward(self, hidden_states):
        projected = self.norm(self.proj(hidden_states))
        return F.normalize(projected, dim=-1)
|
|
|
|
class PositionGeometricContext(TorchComponent):
    """Fuse four geometric streams into a per-position FiLM context.

    Streams:
        anchor:     cosines + assignment + triangulation (WHERE on the manifold)
        structural: patchwork + embedding (WHAT local geometry looks like)
        history:    accumulated geo residual from earlier layers
        quality:    per-anchor CM gate profile (WHICH directions to trust)

    Missing history/quality inputs contribute zeros, so the first layer
    and gate-free calls degrade gracefully.
    """
    def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim):
        super().__init__(name)
        self.context_dim = context_dim
        self.pw_dim = pw_dim

        def _branch(in_dim):
            # All four stream encoders (and the fuser) share this shape.
            return nn.Sequential(
                nn.Linear(in_dim, context_dim), nn.GELU(), nn.LayerNorm(context_dim))

        self.anchor_mlp = _branch(n_anchors * 3)
        self.struct_mlp = _branch(pw_dim + manifold_dim)
        self.history_mlp = _branch(pw_dim)
        self.quality_mlp = _branch(n_anchors)

        self.fuse = _branch(context_dim * 4)

    def forward(self, obs_dict, gate_values=None, geo_residual=None):
        """Build the FiLM context.

        Args:
            obs_dict: observation dict with 'cos_to_anchors', 'assignment',
                'triangulation', 'patchwork', 'embedding', all (N, ...).
            gate_values: (N, A) CM gate values per anchor, or None.
            geo_residual: (N, pw_dim) accumulated context, or None.

        Returns:
            (N, context_dim) geometric context for FiLM.
        """
        anchor_in = torch.cat([
            obs_dict['cos_to_anchors'],
            obs_dict['assignment'],
            obs_dict['triangulation'],
        ], dim=-1)
        struct_in = torch.cat([
            obs_dict['patchwork'],
            obs_dict['embedding'],
        ], dim=-1)

        a = self.anchor_mlp(anchor_in)
        s = self.struct_mlp(struct_in)
        h = torch.zeros_like(a) if geo_residual is None else self.history_mlp(geo_residual)
        q = torch.zeros_like(a) if gate_values is None else self.quality_mlp(gate_values)

        return self.fuse(torch.cat([a, s, h, q], dim=-1))
|
|
|
|
class GeometricAttention(TorchComponent):
    """Geometry-routed self-attention (Stream B).

    FiLM modulates Q and K BEFORE the scores are formed, so geometry
    steers routing while V — the content actually mixed — stays pure.
    A second FiLM conditions the FFN hidden layer.
    """
    def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1):
        super().__init__(name)
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # FiLM on the routing tensors only; V is deliberately unmodulated.
        self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim)
        self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim)
        self.norm = nn.LayerNorm(d_model)

        self.ffn1 = nn.Linear(d_model, d_model * 4)
        self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim)
        self.ffn2 = nn.Linear(d_model * 4, d_model)
        self.ffn_drop = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None):
        """x: (B, L, d_model); geo_ctx: (B, L, context_dim).

        attn_mask is additive; key_padding_mask is boolean (True = pad).
        """
        bsz, seqlen, dm = x.shape
        heads, hdim = self.n_heads, self.head_dim

        q = self.film_q(self.w_q(x), geo_ctx).view(bsz, seqlen, heads, hdim).transpose(1, 2)
        k = self.film_k(self.w_k(x), geo_ctx).view(bsz, seqlen, heads, hdim).transpose(1, 2)
        v = self.w_v(x).view(bsz, seqlen, heads, hdim).transpose(1, 2)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            logits = logits + attn_mask
        if key_padding_mask is not None:
            logits = logits.masked_fill(
                key_padding_mask[:, None, None, :], float('-inf'))
        weights = self.dropout(F.softmax(logits, dim=-1))
        mixed = (weights @ v).transpose(1, 2).reshape(bsz, seqlen, dm)
        x = self.norm(x + self.w_o(mixed))

        # FiLM also modulates the FFN's hidden activation.
        hidden = self.film_ffn(F.gelu(self.ffn1(x)), geo_ctx)
        return self.ffn_norm(x + self.ffn_drop(self.ffn2(hidden)))
|
|
|
|
class ContentAttention(TorchComponent):
    """Plain self-attention block (Stream A) — no geometric conditioning.

    nn.MultiheadAttention with post-norm residuals, followed by a 4x GELU
    feed-forward block.
    """
    def __init__(self, name, d_model, n_heads=8, dropout=0.1):
        super().__init__(name)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(),
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        attended, _ = self.attn(
            x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        x = self.norm(x + attended)
        return self.ffn_norm(x + self.ffn(x))
|
|
|
|
| |
| |
| |
|
|
class GeometricTransformerLayer(BaseTower):
    """One layer of the geometric transformer (CM validated).

    Pipeline per layer:
    1. ManifoldProjection: h -> emb on S^(d-1)
    2. Association: emb -> raw triangulation, cos, assignment
    3. CMValidatedGate: per-anchor CM validity -> gate_values
    4. Gated curation: patchwork reads tri * gate_values
    5. PositionGeometricContext: 4 streams -> FiLM context
    6. ContentAttention (Stream A): standard MHA
    7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx)
    8. CayleyOrthogonal: align B -> A
    9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
    10. Decode + gated residual
    11. CM-conditioned geometric residual accumulation

    The observer is DECOMPOSED: association and curation are called
    separately with the CM gate inserted between them, so the patchwork
    only ever interprets validated (gated) geometry.

    The geometric residual is accumulated using the mean CM gate value as
    the write weight — no learned gate. Positions with high-quality
    simplex observations contribute more; degenerate regions contribute
    less.
    """
    def __init__(self, name, d_model, n_heads=8, n_anchors=32,
                 manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1,
                 cm_neighbors=3):
        super().__init__(name)
        self.d_model = d_model
        self.n_anchors = n_anchors

        # 1. Hidden states -> unit hypersphere.
        self.attach('projection', ManifoldProjection(
            f'{name}_proj', d_model, manifold_dim))

        # 2. Constellation observer (association + curation, invoked separately).
        self.attach('observer', ConstellationObserver(
            dim=manifold_dim, n_anchors=n_anchors,
            n_comp=n_comp, d_comp=d_comp))

        # 3. CM validity gate inserted between association and curation.
        self.attach('cm_gate', CMValidatedGate(
            n_anchors=n_anchors, n_neighbors=cm_neighbors))

        # 4. Four-stream FiLM context builder.
        pw_dim = self['observer'].curation.patchwork.output_dim
        self.attach('context', PositionGeometricContext(
            f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim))

        # 5. Stream A: plain content attention.
        self.attach('content', ContentAttention(
            f'{name}_content', d_model, n_heads, dropout))

        # 6. Stream B: geometry-routed attention (FiLM on Q, K).
        self.attach('geometric', GeometricAttention(
            f'{name}_geo', d_model, n_heads, context_dim, dropout))

        # 7. Pure rotation aligning stream B's basis onto stream A's.
        self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model))

        # 8. Quaternion composition of the two streams.
        self.attach('compose', QuaternionCompose(
            f'{name}_quat', d_model, quat_dim))

        # 9. Decode composed features and blend via a learned sigmoid gate.
        self.attach('decode', nn.Sequential(
            nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model)))
        self.attach('gate', nn.Sequential(
            nn.Linear(d_model * 2, d_model), nn.Sigmoid()))

        # 10. Projection for the accumulated geometric residual stream.
        self._pw_dim = pw_dim
        self.attach('geo_proj', nn.Sequential(
            nn.Linear(pw_dim, pw_dim), nn.LayerNorm(pw_dim)))

    def forward(self, x, geo_residual=None, attn_mask=None, key_padding_mask=None):
        """Run one dual-stream layer.

        Args:
            x: (B, L, D) input hidden states.
            geo_residual: (B, L, pw_dim) accumulated geometric context,
                or None for the first layer.

        Returns:
            x_out: (B, L, D) transformed hidden states.
            geo_residual_out: (B, L, pw_dim) updated geometric residual.
            geo_state: dict with the full geometric state + CM diagnostics.
        """
        B, L, D = x.shape

        # (1) Project to the manifold; positions flattened to (B*L, manifold_dim).
        emb = self['projection'](x)
        emb_flat = emb.reshape(B * L, -1)

        # (2) Raw (ungated) association.
        a_out = self['observer'].association(emb_flat)

        # (3) CM gate on detached anchors — the gate learns validity
        #     without pushing gradients into the constellation.
        anchors_n = F.normalize(
            self['observer'].association.constellation.anchors, dim=-1)
        gate_values, gate_info = self['cm_gate'](
            emb_flat, anchors_n.detach(), a_out['distances'])

        # (4) Curation reads distances scaled by the validity gate.
        a_out_gated = dict(a_out)
        a_out_gated['distances_weighted'] = a_out['distances'] * gate_values
        c_out = self['observer'].curation.curate_full(a_out_gated, emb=emb_flat)

        # Observation dict consumed by the FiLM context builder.
        obs = {
            'embedding': emb_flat,
            'triangulation': a_out['distances'],
            'cos_to_anchors': a_out['cos_to_anchors'],
            'assignment': a_out['assignment'],
            'nearest': a_out['nearest'],
            'patchwork': c_out['patchwork'],
            'bridge': c_out['bridge'],
        }

        # (5) Four-stream geometric context (anchor/structural/history/quality).
        geo_res_flat = geo_residual.reshape(B * L, -1) if geo_residual is not None else None
        geo_ctx_flat = self['context'](
            obs, gate_values=gate_values, geo_residual=geo_res_flat)
        geo_ctx = geo_ctx_flat.reshape(B, L, -1)

        # (6) Stream A: content attention, no geometric conditioning.
        a_out_stream = self['content'](
            x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

        # (7) Stream B: geometry-routed attention.
        b_out = self['geometric'](
            x, geo_ctx, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

        # (8) Pure rotation aligning B onto A.
        b_aligned = self['rotation'](b_out)

        # (9) Quaternion composition: difference and elementwise-product
        #     arms couple the streams non-commutatively.
        composed = self['compose'](
            arm_w=a_out_stream, arm_i=b_aligned,
            arm_j=a_out_stream - b_aligned, arm_k=a_out_stream * b_aligned)

        # (10) Decode and blend with the input through a learned gate.
        decoded = self['decode'](composed)
        g = self['gate'](torch.cat([x, decoded], dim=-1))
        x_out = g * decoded + (1 - g) * x

        # (11) CM-conditioned residual write: the mean gate value per
        #      position weights how much validated patchwork is added.
        pw_validated = c_out['patchwork'].reshape(B, L, -1)
        cm_quality = gate_values.mean(dim=-1).reshape(B, L, 1)
        geo_update = self['geo_proj'](pw_validated)

        if geo_residual is None:
            geo_residual_out = cm_quality * geo_update
        else:
            geo_residual_out = geo_residual + cm_quality * geo_update

        # Reshape flattened (B*L, ...) diagnostics back to (B, L, ...).
        def _unflatten(t):
            if t is None:
                return None
            if t.dim() == 1:
                return t.reshape(B, L)
            return t.reshape(B, L, *t.shape[1:])

        geo_state = {
            'embedding': emb,
            'geo_ctx': geo_ctx,
            'triangulation': _unflatten(a_out['distances']),
            'cos_to_anchors': _unflatten(a_out['cos_to_anchors']),
            'assignment': _unflatten(a_out['assignment']),
            'nearest': _unflatten(a_out['nearest']),
            'patchwork': _unflatten(c_out['patchwork']),
            'bridge': _unflatten(c_out['bridge']),
            'gate_values': _unflatten(gate_values),
            'gate_info': gate_info,
            'cm_quality': cm_quality,
            'content': a_out_stream,
            'geometric': b_out,
            'composed': composed,
            'geo_residual': geo_residual_out,
        }

        return x_out, geo_residual_out, geo_state
|
|
|
|
| |
| |
| |
|
|
class GeometricTransformer(BaseTower):
    """Geometric Transformer — CM-validated dual-stream.

    Stack of GeometricTransformerLayers with:
    - CM-gated observation at every layer
    - Cross-layer Cayley rotation on hidden states (not geo_residual)
    - Built-in geometric regularization via geometric_losses()
    """
    def __init__(self, name, d_model=512, n_heads=8, n_layers=4,
                 n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1,
                 cross_layer_rotation=True, cm_neighbors=3,
                 nce_bank_size=4096, nce_temperature=0.1,
                 vocab_size=None, max_seq_len=2048):
        """Build the layer stack plus optional token interface and NCE machinery.

        Args:
            name: tower name; also used to derive sub-module names.
            d_model: hidden width of the transformer stream.
            n_heads: attention heads, forwarded to each layer.
            n_layers: number of GeometricTransformerLayer blocks.
            n_anchors: constellation anchors per layer.
            manifold_dim, n_comp, d_comp, context_dim, quat_dim, dropout,
                cm_neighbors: forwarded verbatim to each layer.
            cross_layer_rotation: if True (and n_layers > 1), attach a
                CayleyOrthogonal rotation between consecutive layers.
            nce_bank_size: if > 0, attach the two InfoNCE projection heads
                and a GeoResidualBank of this size; 0 disables NCE entirely.
            nce_temperature: temperature passed to the GeoResidualBank.
            vocab_size: if not None, attach token/position embeddings and an
                untied linear LM head so forward() can take integer tokens.
            max_seq_len: position-embedding table size (used only when
                vocab_size is given).
        """
        super().__init__(name)
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_anchors = n_anchors
        # Patchwork feature width per position: n components of d_comp dims.
        self._pw_dim = n_comp * d_comp

        # Optional token-level interface (LM-style usage). When absent,
        # forward() expects pre-embedded (B, L, d_model) float inputs.
        if vocab_size is not None:
            self.attach('embed', nn.Embedding(vocab_size, d_model))
            self.attach('pos_embed', nn.Embedding(max_seq_len, d_model))
            self.attach('head', nn.Linear(d_model, vocab_size, bias=False))

        # Core stack: one CM-gated dual-stream layer per depth.
        for i in range(n_layers):
            self.attach(f'layer_{i}', GeometricTransformerLayer(
                f'{name}_L{i}', d_model, n_heads, n_anchors,
                manifold_dim, n_comp, d_comp, context_dim, quat_dim,
                dropout, cm_neighbors))

        # Pure rotations between layers act on hidden states only; the
        # geo_residual stream passes between layers unrotated.
        if cross_layer_rotation and n_layers > 1:
            for i in range(n_layers - 1):
                self.attach(f'cross_rot_{i}', CayleyOrthogonal(
                    f'{name}_xrot_{i}', d_model))

        self.attach('final_norm', nn.LayerNorm(d_model))

        # Cross-stream InfoNCE: content and geometry are projected into a
        # shared space; geo keys are enqueued into a fixed-size bank.
        if nce_bank_size > 0:
            nce_proj_dim = 128
            self.attach('nce_content_proj', nn.Sequential(
                nn.Linear(d_model, nce_proj_dim),
                nn.GELU(),
                nn.Linear(nce_proj_dim, nce_proj_dim),
            ))
            self.attach('nce_geo_proj', nn.Sequential(
                nn.Linear(self._pw_dim, nce_proj_dim),
                nn.GELU(),
                nn.Linear(nce_proj_dim, nce_proj_dim),
            ))
            self.attach('nce_bank', GeoResidualBank(
                nce_proj_dim, bank_size=nce_bank_size,
                temperature=nce_temperature))

        # Full constructor configuration, exposed read-only via .config.
        self._config = dict(
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            n_anchors=n_anchors, manifold_dim=manifold_dim,
            n_comp=n_comp, d_comp=d_comp, context_dim=context_dim,
            quat_dim=quat_dim, dropout=dropout,
            cross_layer_rotation=cross_layer_rotation,
            cm_neighbors=cm_neighbors, vocab_size=vocab_size,
            nce_bank_size=nce_bank_size, nce_temperature=nce_temperature,
        )

    @property
    def config(self):
        """Copy of the constructor configuration (safe for callers to mutate)."""
        return self._config.copy()

    def geometric_losses(self, cv_target=0.215, cv_weight=0.1, spread_weight=0.01):
        """Compute geometric regularization from current anchor geometry.

        These losses maintain the constellation in the regime where
        CM validation, patchwork interpretation, and the full observation
        pipeline produce meaningful results.

        CV loss: push anchor coefficient of variation toward pentachoron
        band (0.20-0.23). This is where CM computation has maximal
        discriminative power — anchors are neither too uniform (CV→0,
        CM uninformative) nor too clustered (CV>0.3, degenerate simplices).

        Spread loss: penalize positive cosine similarity between anchors.
        Prevents collapse where multiple anchors occupy the same region,
        creating redundant measurements and wasting patchwork capacity.

        Args:
            cv_target: target coefficient of variation for anchor distances.
            cv_weight: weight on the squared CV deviation term.
            spread_weight: weight on the mean positive off-diagonal cosine.

        Returns:
            dict with 'cv', 'spread', 'geo_total' loss tensors
            (empty dict if the model has no layers).
        """
        total_cv = torch.tensor(0.0)
        total_spread = torch.tensor(0.0)
        n = 0

        for i in range(self.n_layers):
            layer = self[f'layer_{i}']
            anchors = layer['observer'].association.constellation.anchors
            anchors_n = F.normalize(anchors, dim=-1)
            A = anchors_n.shape[0]

            # Lazily move accumulators to the anchors' device on first layer
            # (keeps this method device-agnostic without extra arguments).
            if n == 0:
                total_cv = total_cv.to(anchors.device)
                total_spread = total_spread.to(anchors.device)

            # Pairwise chordal distances (1 - cos) over the strict upper
            # triangle; CV = std/mean with an epsilon against zero mean.
            cos = anchors_n @ anchors_n.T
            idx = torch.triu_indices(A, A, offset=1, device=cos.device)
            pairwise_dist = 1.0 - cos[idx[0], idx[1]]
            cv = pairwise_dist.std() / (pairwise_dist.mean() + 1e-8)
            total_cv = total_cv + (cv - cv_target).pow(2)

            # Only POSITIVE cosines are penalized (ReLU): anchors may be
            # anti-aligned freely, but must not cluster.
            mask = ~torch.eye(A, dtype=torch.bool, device=cos.device)
            total_spread = total_spread + F.relu(cos[mask]).mean()

            n += 1

        losses = {}
        if n > 0:
            losses['cv'] = cv_weight * total_cv / n
            losses['spread'] = spread_weight * total_spread / n
            losses['geo_total'] = losses['cv'] + losses['spread']
        return losses

    def infonce_loss(self, cls_index=0):
        """Cross-stream contrastive: content queries against decoupled geometry.

        The constellation provides a STABLE geometric reference frame.
        The content stream needs discriminative correction.
        The InfoNCE targets weaker content representations by measuring
        them against the constellation's observation.

        Gradient path (info-side only):
        - nce_content_proj <- hidden_cls <- transformer <- input (LIVE)
        - nce_geo_proj <- learns to read detached residual (LIVE proj, FROZEN input)
        - geo_residual <- constellation/patchwork/geo_proj (DETACHED -> decoupled)

        The constellation's anchors never see NCE gradient.
        Both projection heads learn from InfoNCE to find shared space.
        Content stream receives corrective gradient at weak positions.

        Args:
            cls_index: sequence position used as the image/sequence-level token.

        Returns:
            dict with 'nce': loss tensor, 'nce_acc': retrieval accuracy
            (empty dict if NCE is disabled or no forward pass has run yet).
        """
        if not self.has('nce_bank'):
            return {}

        # Both tensors are cached as side effects of forward(); bail out
        # gracefully if forward() has not been called this step.
        hidden = getattr(self, '_last_hidden', None)
        geo_residual = getattr(self, '_last_geo_residual', None)
        if hidden is None or geo_residual is None:
            return {}

        # Content query: LIVE gradient back into the transformer.
        content_cls = self['nce_content_proj'](hidden[:, cls_index])

        # Geometry key: residual is DETACHED so only the projection head
        # (not the constellation/patchwork) receives NCE gradient.
        geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach())

        loss, acc = self['nce_bank'](content_cls, geo_cls)
        return {'nce': loss, 'nce_acc': acc}

    @torch.no_grad()
    def update_nce_bank(self, cls_index=0):
        """Enqueue projected geo keys into bank. Call AFTER backward."""
        if not self.has('nce_bank') or not self.has('nce_geo_proj'):
            return

        # Requires a cached residual from the most recent forward().
        geo_residual = getattr(self, '_last_geo_residual', None)
        if geo_residual is None:
            return

        # Keys are L2-normalized before entering the bank.
        geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach())
        self['nce_bank'].enqueue(F.normalize(geo_cls, dim=-1))

    def anchor_diagnostics(self):
        """Per-layer anchor health diagnostics. Call for monitoring.

        Returns:
            dict mapping 'layer_{i}' to scalar stats: anchor CV, mean/min
            pairwise distance, and CM sign/mean/std over anchor neighborhoods.
        """
        diag = {}
        for i in range(self.n_layers):
            layer = self[f'layer_{i}']
            anchors = layer['observer'].association.constellation.anchors
            anchors_n = F.normalize(anchors.detach(), dim=-1)
            A = anchors_n.shape[0]

            # Same CV computation as geometric_losses(), but detached.
            cos = anchors_n @ anchors_n.T
            idx = torch.triu_indices(A, A, offset=1, device=cos.device)
            pairwise = 1.0 - cos[idx[0], idx[1]]
            cv = (pairwise.std() / (pairwise.mean() + 1e-8)).item()

            # CM over each anchor's k-neighborhood, using the same k the
            # layer's gate uses, so diagnostics match gating behavior.
            with torch.no_grad():
                anchor_cm, _ = anchor_neighborhood_cm(
                    anchors_n, layer['cm_gate'].n_neighbors)

            diag[f'layer_{i}'] = {
                'anchor_cv': cv,
                'mean_pairwise_dist': pairwise.mean().item(),
                'min_pairwise_dist': pairwise.min().item(),
                'cm_positive_frac': (anchor_cm > 0).float().mean().item(),
                'cm_mean': anchor_cm.mean().item(),
                'cm_std': anchor_cm.std().item(),
            }
        return diag

    def param_report(self):
        """Print a per-child parameter-count table; return the total count."""
        total = 0
        name = getattr(self, '_tower_name', self.__class__.__name__)
        print(f"\n {name} β parameter report (CM-validated)")
        print(f" {'Component':<35s} {'Params':>12s}")
        print(f" {'β'*35} {'β'*12}")
        for cname, module in self.named_children():
            n = sum(p.numel() for p in module.parameters())
            total += n
            print(f" {cname:<35s} {n:>12,}")
        print(f" {'β'*35} {'β'*12}")
        print(f" {'TOTAL':<35s} {total:>12,}")
        return total

    def forward(self, x, attn_mask=None, key_padding_mask=None,
                return_geo_state=False):
        """
        Args:
            x: (B, L, D) float hidden states, or (B, L) integer token ids
               when the embedding table is attached.
            attn_mask, key_padding_mask: forwarded to every layer.
            return_geo_state: also collect per-layer geo_state dicts.

        Returns:
            out: (B, L, D) transformed hidden states (or logits if head attached)
            geo_states: list of per-layer geo_state dicts (if return_geo_state)

        Side effect:
            self._last_geo_residual is set to the final geo_residual (B, L, pw_dim)
            and self._last_hidden to the PRE-norm hidden states, for use by
            infonce_loss() and update_nce_bank() without changing the return API.
        """
        if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
            pos = torch.arange(x.shape[1], device=x.device)
            x = self['embed'](x) + self['pos_embed'](pos)

        geo_states = []
        has_xrot = self.has('cross_rot_0')
        geo_residual = None  # accumulated CM-conditioned geometric residual

        for i in range(self.n_layers):
            x, geo_residual, geo_state = self[f'layer_{i}'](
                x, geo_residual=geo_residual,
                attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            if return_geo_state:
                geo_states.append(geo_state)
            # Cross-layer rotation applies only BETWEEN layers, never after
            # the last one, and only to hidden states (not geo_residual).
            if has_xrot and i < self.n_layers - 1:
                x = self[f'cross_rot_{i}'](x)

        # Cache for InfoNCE helpers; note _last_hidden is pre-final_norm.
        self._last_geo_residual = geo_residual
        self._last_hidden = x

        x = self['final_norm'](x)
        if self.has('head'):
            x = self['head'](x)

        return (x, geo_states) if return_geo_state else x

    def _run_view(self, x, attn_mask=None, key_padding_mask=None):
        """Run one view through the full pipeline.

        Like forward(), but always collects geo_states, never applies the
        LM head, and does NOT touch the _last_* caches.

        Returns:
            features: (B, L, D) transformed hidden states (post-norm)
            geo_states: list of per-layer geo_state dicts
        """
        geo_states = []
        has_xrot = self.has('cross_rot_0')
        geo_residual = None

        if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
            pos = torch.arange(x.shape[1], device=x.device)
            x = self['embed'](x) + self['pos_embed'](pos)

        for i in range(self.n_layers):
            x, geo_residual, geo_state = self[f'layer_{i}'](
                x, geo_residual=geo_residual,
                attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            geo_states.append(geo_state)
            if has_xrot and i < self.n_layers - 1:
                x = self[f'cross_rot_{i}'](x)

        x = self['final_norm'](x)
        return x, geo_states

    def forward_paired(self, x1, x2, cls_index=0,
                       attn_mask=None, key_padding_mask=None):
        """Dual-view forward for observer loss training.

        Runs both views through the full CM-gated pipeline, extracts
        CLS-position geometric state from the final layer, and packages
        into the observe_paired output format expected by observer_loss().

        Args:
            x1, x2: (B, L, D) two views of input hidden states
            cls_index: position index for image-level outputs (default 0)

        Returns:
            output dict matching observer_loss spec:
                embedding, embedding_aug, patchwork1, patchwork1_aug,
                bridge1, bridge2, assign1, assign2, cos1, tri1, tri2
            Plus: features1, features2, geo_states1, geo_states2
            and CM-gate extras gate_values1/2, cm_quality1/2.
        """
        feat1, gs1 = self._run_view(x1, attn_mask, key_padding_mask)
        feat2, gs2 = self._run_view(x2, attn_mask, key_padding_mask)

        # Observer loss reads the FINAL layer's geometric state only.
        g1 = gs1[-1]
        g2 = gs2[-1]
        c = cls_index

        return {
            # CLS-position geometric observations, per view.
            'embedding': g1['embedding'][:, c],
            'embedding_aug': g2['embedding'][:, c],
            'patchwork1': g1['patchwork'][:, c],
            'patchwork1_aug': g2['patchwork'][:, c],
            'bridge1': g1['bridge'][:, c],
            'bridge2': g2['bridge'][:, c],
            'assign1': g1['assignment'][:, c],
            'assign2': g2['assignment'][:, c],
            'cos1': g1['cos_to_anchors'][:, c],
            'tri1': g1['triangulation'][:, c],
            'tri2': g2['triangulation'][:, c],
            # Full post-norm feature maps for downstream heads.
            'features1': feat1,
            'features2': feat2,
            # CM-gate extras for monitoring / auxiliary losses.
            'gate_values1': g1['gate_values'][:, c],
            'gate_values2': g2['gate_values'][:, c],
            'cm_quality1': g1['cm_quality'],
            'cm_quality2': g2['cm_quality'],
            'geo_states1': gs1,
            'geo_states2': gs2,
        }

    def compute_loss(self, output, targets, cls_index=0,
                     w_ce=1.0, head=None, **loss_kwargs):
        """Three-domain observer loss through the CM-gated pipeline.

        Follows ConstellationEncoder.compute_loss pattern:
        observer_loss (geometric + internal) + CE (external)

        The observer_loss reads patchwork, bridge, assign, tri, cos —
        all of which flowed through the CM gate during forward_paired.

        Args:
            output: dict from forward_paired()
            targets: (B,) class labels
            cls_index: which position has the CLS token
            w_ce: weight on cross-entropy loss
            head: nn.Module mapping (B, D) -> (B, num_classes), or None
            **loss_kwargs: forwarded to observer_loss (w_nce_pw, w_bridge, etc.)

        Returns:
            (total_loss, loss_dict)
        """
        # Anchors from the FINAL layer define the geometric frame for the
        # observer loss (forward_paired exports final-layer state only).
        final_layer = self[f'layer_{self.n_layers - 1}']
        anchors = final_layer['observer'].association.constellation.anchors

        obs_loss, ld = _geolip_observer_loss(
            output, anchors=anchors, targets=targets,
            **loss_kwargs)

        # External domain: supervised CE on both views' CLS features.
        if head is not None:
            feat1 = output['features1'][:, cls_index]
            feat2 = output['features2'][:, cls_index]
            logits1 = head(feat1)
            logits2 = head(feat2)
            l_ce, acc = _geolip_ce_loss_paired(logits1, logits2, targets)
            ld['ce'], ld['acc'] = l_ce, acc
            ld['logits'] = logits1
            loss = w_ce * l_ce + obs_loss
            ld['loss_task'] = l_ce.item()
        else:
            loss = obs_loss

        # Diagnostic: anchor spread across ALL layers. Reported in the loss
        # dict but intentionally NOT added to the total loss.
        total_spread = torch.tensor(0.0, device=anchors.device)
        for i in range(self.n_layers):
            layer = self[f'layer_{i}']
            layer_anchors = layer['observer'].association.constellation.anchors
            total_spread = total_spread + _geolip_spread_loss(layer_anchors)
        ld['spread_all_layers'] = total_spread / self.n_layers

        ld['loss_observer'] = obs_loss.item()
        ld['total'] = loss
        return loss, ld
|
|
|
|
| |
| |
| |
|
|
def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw):
    """Factory: GeometricTransformer sized for ESM-2 650M embeddings (d=1280)."""
    base = dict(
        d_model=1280, n_heads=16, n_layers=n_layers,
        n_anchors=32, manifold_dim=256,
        n_comp=8, d_comp=32, context_dim=128, quat_dim=64,
    )
    # Duplicate keywords in **kw raise TypeError, same as explicit kwargs.
    return GeometricTransformer(name, **base, **kw)
|
|
def geo_transformer_small(name='geo_small', n_layers=4, **kw):
    """Factory: compact GeometricTransformer for prototyping (d=256)."""
    base = dict(
        d_model=256, n_heads=8, n_layers=n_layers,
        n_anchors=16, manifold_dim=128,
        n_comp=4, d_comp=16, context_dim=64, quat_dim=32,
    )
    # Duplicate keywords in **kw raise TypeError, same as explicit kwargs.
    return GeometricTransformer(name, **base, **kw)
|
|
def geo_transformer_vision(name='geo_vit', n_layers=4, **kw):
    """Factory: GeometricTransformer for the scatter/SVD vision pipeline
    (image patches as tokens, d=384)."""
    base = dict(
        d_model=384, n_heads=8, n_layers=n_layers,
        n_anchors=32, manifold_dim=128,
        n_comp=8, d_comp=16, context_dim=64, quat_dim=32,
    )
    # Duplicate keywords in **kw raise TypeError, same as explicit kwargs.
    return GeometricTransformer(name, **base, **kw)
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Self-test: build a small model, exercise forward/backward, paired
    # forward, and every diagnostic path, printing results along the way.
    print("Geometric Transformer β CM Validated β Self-Test")
    print(f" geolip_core available: {_HAS_GEOLIP}")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- model construction -------------------------------------------
    model = geo_transformer_small('test_cm', n_layers=2)
    if hasattr(model, 'network_to'):
        # BaseTower-style device move when available; plain .to() otherwise.
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)
    total = model.param_report()

    # --- basic forward shape check -------------------------------------
    B, L, D = 2, 32, 256
    x = torch.randn(B, L, D, device=device)
    out, geos = model(x, return_geo_state=True)

    assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}"
    assert len(geos) == 2
    print(f"\n Input: ({B}, {L}, {D})")
    print(f" Output: {out.shape}")
    print(f" Geo states: {len(geos)} layers")

    # --- per-layer CM gate statistics ----------------------------------
    for i, gs in enumerate(geos):
        gi = gs['gate_info']
        cm_q = gs['cm_quality']
        gv = gs['gate_values']
        print(f"\n Layer {i} CM gate:")
        print(f" active anchors: {gi['active'].item():.1f} / {model.n_anchors}")
        print(f" gate mean: {gi['gate_mean'].item():.4f}")
        print(f" cm_positive_frac: {gi['cm_positive_frac'].item():.3f}")
        print(f" gate_values: {gv.shape} range=[{gv.min():.3f}, {gv.max():.3f}]")
        print(f" cm_quality: {cm_q.shape} mean={cm_q.mean():.4f}")

    # --- geo residual accumulation across layers ------------------------
    gr0 = geos[0]['geo_residual']
    gr1 = geos[1]['geo_residual']
    print(f"\n Geo residual stream:")
    print(f" Layer 0: {gr0.shape} norm={gr0.norm(dim=-1).mean():.4f}")
    print(f" Layer 1: {gr1.shape} norm={gr1.norm(dim=-1).mean():.4f}")

    # --- geometric regularization terms ---------------------------------
    geo_losses = model.geometric_losses()
    print(f"\n Geometric regularization:")
    for k, v in geo_losses.items():
        print(f" {k}: {v.item():.6f}")

    # --- anchor health diagnostics ---------------------------------------
    diag = model.anchor_diagnostics()
    print(f"\n Anchor diagnostics:")
    for layer_name, d in diag.items():
        print(f" {layer_name}:")
        for k, v in d.items():
            print(f" {k}: {v:.4f}")

    # --- Cayley rotations: verify orthogonality and det=+1 ----------------
    print(f"\n Cayley rotations:")
    for name, module in model.named_modules():
        if isinstance(module, CayleyOrthogonal):
            R = module.get_rotation()
            I = torch.eye(R.shape[0], device=R.device)
            print(f" {name}: βRRα΅-Iβ={((R@R.T)-I).norm():.8f} det={torch.det(R):.4f}")

    # --- gradient flow through the CM gates ------------------------------
    print(f"\n Gradient flow test:")
    model.zero_grad()
    x_grad = torch.randn(B, L, D, device=device, requires_grad=True)
    out_grad = model(x_grad)
    loss = out_grad.sum()
    loss.backward()

    # Every cm_gate parameter should receive a nonzero gradient.
    for i in range(model.n_layers):
        layer = model[f'layer_{i}']
        gate_grads = [p.grad is not None and p.grad.abs().sum() > 0
                      for p in layer['cm_gate'].parameters()]
        print(f" layer_{i} cm_gate grad: {'YES' if all(gate_grads) else 'NO'}")

    # --- one full optimizer step with task + geometric losses --------------
    print(f"\n Training step simulation:")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    optimizer.zero_grad()

    x_train = torch.randn(B, L, D, device=device)
    out_train, states = model(x_train, return_geo_state=True)
    task_loss = out_train.mean()

    geo_losses = model.geometric_losses()
    total_loss = task_loss + geo_losses.get('geo_total', 0.0)
    total_loss.backward()
    optimizer.step()
    print(f" task_loss: {task_loss.item():.4f}")
    print(f" cv_loss: {geo_losses['cv'].item():.6f}")
    print(f" spread_loss:{geo_losses['spread'].item():.6f}")
    print(f" total: {total_loss.item():.4f}")

    # --- paired forward + three-domain loss (needs geolip_core) ------------
    if _HAS_GEOLIP:
        print(f"\n Paired forward + observer loss:")
        model.zero_grad()

        # Two correlated views: second is a lightly perturbed copy.
        x1 = torch.randn(B, L, D, device=device)
        x2 = x1 + 0.1 * torch.randn_like(x1)
        targets = torch.randint(0, 10, (B,), device=device)

        output = model.forward_paired(x1, x2)
        print(f" Output keys: {sorted(k for k in output if not k.startswith('geo_'))}")
        for k in ['embedding', 'patchwork1', 'bridge1', 'assign1', 'tri1']:
            print(f" {k}: {output[k].shape}")

        # Small classification head for the external (CE) domain.
        num_classes = 10
        head = nn.Linear(D, num_classes).to(device)

        loss, ld = model.compute_loss(output, targets, head=head)
        print(f"\n Three-domain loss breakdown:")
        for k in ['loss_observer', 'loss_task', 'ce', 'nce_emb', 'nce_pw',
                  'bridge', 'assign', 'assign_nce', 'nce_tri', 'attract',
                  'cv', 'spread']:
            if k in ld:
                v = ld[k]
                v = v.item() if isinstance(v, torch.Tensor) else v
                print(f" {k:16s} = {v:.4f}")
        for k in ['nce_emb_acc', 'nce_pw_acc', 'nce_tri_acc', 'bridge_acc',
                  'assign_nce_acc', 'acc']:
            if k in ld:
                v = ld[k]
                v = v if isinstance(v, float) else v.item()
                print(f" {k:16s} = {v*100:.1f}%")
        print(f" {'TOTAL':16s} = {loss.item():.4f}")

        # Count parameters that actually received gradient.
        loss.backward()
        alive, dead = 0, 0
        for n, p in model.named_parameters():
            if p.grad is not None and p.grad.norm() > 0:
                alive += 1
            else:
                dead += 1
        print(f"\n Gradient flow: {alive} params alive, {dead} dead")

        # Per-layer: do the gate and observer components receive gradient?
        for i in range(model.n_layers):
            layer = model[f'layer_{i}']
            for comp_name in ['cm_gate', 'observer']:
                has = any(p.grad is not None and p.grad.norm() > 0
                          for p in layer[comp_name].parameters())
                print(f" layer_{i}.{comp_name}: {'LIVE' if has else 'DEAD'}")

        # Bridge modules inside the curation stage, checked separately.
        for i in range(model.n_layers):
            layer = model[f'layer_{i}']
            bridge = layer['observer'].curation.bridge
            has = any(p.grad is not None and p.grad.norm() > 0
                      for p in bridge.parameters())
            print(f" layer_{i}.bridge: {'LIVE' if has else 'DEAD'}")
    else:
        print(f"\n [SKIP] forward_paired + compute_loss require geolip_core imports")

    print(f"\n{'='*60}")
    print(f" PASSED β CM-validated pipeline operational")
    print(f"{'='*60}")