| """ |
| Geometric Transformer β CM-Validated Pipeline |
| ================================================== |
| Dual-stream transformer with CM-gated constellation observation, |
| quaternion composition, and per-layer Cayley alignment. |
| |
| CM-validated pipeline changes: |
| - CM validity gate between association and curation (AnchorGate) |
| - 5-stream PositionGeometricContext: anchor + structural + history + quality + FLOW |
| - CM-conditioned geometric residual accumulation (replaces blind learned gate) |
| - Built-in geometric regularization (CV target + anchor spread) |
| - Decomposed observer pipeline: association β CM gate β gated curation |
| - Optional FlowEnsemble: multi-opinion geometric fusion (quat, velocity, orbital, etc.) |
| |
| Pipeline per layer: |
| 1. ManifoldProjection: h_i β emb_i on S^(d-1) per position |
| 2. ConstellationAssociation: emb_i β raw triangulation, cos, assignment |
| 3. CMValidatedGate: per-anchor CM validity β gate_values (B*L, A) |
| 4. Gated curation: patchwork reads tri * gate_values (validated only) |
| 4.5 FlowEnsemble (optional): multi-opinion geometric predictions |
| 5. PositionGeometricContext: 5 streams β FiLM context (B, L, context_dim) |
| 6. ContentAttention (Stream A): standard MHA |
| 7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure |
| 8. CayleyOrthogonal: align B β A basis |
| 9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B |
| 10. Decode + gated residual |
| 11. CM-conditioned geometric residual write |
| |
| Geometric regularization (call model.geometric_losses() during training): |
| - CV loss: anchor CV β pentachoron band (0.20-0.23) |
| - Spread loss: prevent anchor collapse (penalize positive cosine) |
| These maintain the constellation in the regime where CM validation works. |
| |
| Design principles from Ryan Spearman (Ο=0.309, 76/84 wins): |
| - FiLM on Q,K ONLY β geometry routes attention, V stays pure |
| - FiLM on individual arms BEFORE composition, not after |
| - Quaternion algebra as structural regularizer (non-commutative coupling) |
| - CayleyOrthogonal guarantees pure rotation (det=1 always) |
| - Never global average pool β per-position geometric context |
| |
| Author: AbstractPhil + Claude Opus 4.6 |
| License: Apache 2.0 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
| from geolip_core.core.associate.constellation import ( |
| ConstellationObserver, ConstellationAssociation, ConstellationCuration, |
| Constellation, init_anchors_repulsion, |
| ) |
| from geolip_core.core.curate.gate import AnchorGate as _GeolipAnchorGate |
| from geolip_core.core.curate.flows import FlowEnsemble |
| import geolip_core.linalg as LA |
| from geolip_core.pipeline.observer import ( |
| TorchComponent, BaseTower, Input, Curation, Distinction, |
| ) |
| from geolip_core.core.distinguish.losses import ( |
| observer_loss as _geolip_observer_loss, |
| ce_loss_paired as _geolip_ce_loss_paired, |
| spread_loss as _geolip_spread_loss, |
| ) |
|
|
| |
| try: |
| from geofractal.router.wide_router import WideRouter |
| _HAS_WIDE_ROUTER = True |
| except ImportError: |
| _HAS_WIDE_ROUTER = False |
|
|
| |
| |
| |
|
|
| def pairwise_distances_squared(points): |
| """Batched pairwise squared distances. (B, N, D) β (B, N, N).""" |
| gram = torch.bmm(points, points.transpose(1, 2)) |
| diag = gram.diagonal(dim1=-2, dim2=-1) |
| return diag.unsqueeze(2) + diag.unsqueeze(1) - 2 * gram |
|
|
|
|
| def cayley_menger_det(points): |
| """Cayley-Menger signed volumeΒ² for simplices. (B, K, D) β (B,). |
| |
| K = number of vertices (k+1 for a k-simplex). |
| Sign-corrected: positive = valid non-degenerate simplex. |
| """ |
| B, K, D = points.shape |
| d2 = pairwise_distances_squared(points) |
| M = torch.zeros(B, K + 1, K + 1, device=points.device, dtype=points.dtype) |
| M[:, 0, 1:] = 1.0 |
| M[:, 1:, 0] = 1.0 |
| M[:, 1:, 1:] = d2 |
| raw = LA.det(M) |
| k = K - 1 |
| sign = (-1.0) ** (k + 1) |
| return sign * raw |
|
|
|
|
| def anchor_neighborhood_cm(anchors, n_neighbors=3): |
| """Precompute per-anchor CM quality from local neighborhood geometry. |
| |
| Position-independent. O(A) determinant computations on small matrices. |
| Each anchor forms a simplex with its k nearest neighbor anchors. |
| The CM determinant measures local geometric quality β high volume means |
| the anchor neighborhood is well-conditioned for triangulation. |
| |
| Args: |
| anchors: (A, D) normalized anchor positions on S^(d-1) |
| n_neighbors: neighbors per simplex |
| |
| Returns: |
| quality: (A,) signed log-magnitude CM quality per anchor |
| nn_idx: (A, n_neighbors) neighbor indices |
| """ |
| A, D = anchors.shape |
| dists = torch.cdist(anchors.unsqueeze(0), anchors.unsqueeze(0)).squeeze(0) |
| |
| self_mask = torch.eye(A, device=anchors.device, dtype=anchors.dtype) * 1e12 |
| dists = dists + self_mask |
| _, nn_idx = dists.topk(n_neighbors, largest=False) |
|
|
| |
| simplices = torch.cat([ |
| anchors.unsqueeze(1), |
| anchors[nn_idx], |
| ], dim=1) |
|
|
| dets = cayley_menger_det(simplices) |
| sign = dets.sign() |
| log_mag = torch.log(dets.abs() + 1e-12) |
| return sign * log_mag, nn_idx |
|
|
|
|
| |
| |
| |
|
|
| class CMValidatedGate(nn.Module): |
| """Anchor gate based on Cayley-Menger validity. |
| |
| Efficient for transformer scale: anchor CM quality is precomputed O(AΒ²) |
| and CACHED (only recomputed on invalidate_cache()), then combined with |
| per-position proximity features through a learned gate. |
| |
| The gate starts OPEN (bias=+2, sigmoidβ0.88) and learns to CLOSE on |
| geometrically invalid configurations. Architecture-before-loss: the gate |
| suppresses degenerate measurements structurally, not through a loss signal. |
| |
| Gate features per (position, anchor): |
| - anchor_cm_quality: CM volume of anchor's local neighborhood (cached) |
| - cos_to_anchor: cosine similarity (position-dependent) |
| |
| Args: |
| n_anchors: number of constellation anchors |
| n_neighbors: neighbors for CM simplex computation |
| """ |
| def __init__(self, n_anchors, n_neighbors=3): |
| super().__init__() |
| self.n_anchors = n_anchors |
| self.n_neighbors = n_neighbors |
|
|
| |
| self.gate_proj = nn.Sequential( |
| nn.Linear(2, 16), |
| nn.GELU(), |
| nn.Linear(16, 1), |
| ) |
| |
| |
| nn.init.normal_(self.gate_proj[2].weight, std=0.01) |
| nn.init.constant_(self.gate_proj[2].bias, 2.0) |
|
|
| |
| self._cached_cm_norm = None |
|
|
| def invalidate_cache(self): |
| """Call after optimizer.step() to recompute anchor CM next forward.""" |
| self._cached_cm_norm = None |
|
|
| def precompute(self, anchors): |
| """Compute anchor CM norm OUTSIDE compile graph. |
| Called from layer forward before the compilable gate computation. |
| Idempotent: skips if cache is warm. |
| """ |
| if self._cached_cm_norm is not None: |
| return |
| with torch.no_grad(): |
| anchor_cm, _ = anchor_neighborhood_cm(anchors, self.n_neighbors) |
| cm_std = anchor_cm.std().clamp(min=1e-8) |
| self._cached_cm_norm = ((anchor_cm - anchor_cm.mean()) / cm_std).detach() |
|
|
| def _compute_gate(self, anchor_cm_norm, tri): |
| """Fully compilable β pure tensor ops, no linalg, no graph breaks.""" |
| N, A = tri.shape |
| cos_sim = 1.0 - tri |
|
|
| features = torch.stack([ |
| anchor_cm_norm.unsqueeze(0).expand(N, -1), |
| cos_sim, |
| ], dim=-1) |
|
|
| gate_values = torch.sigmoid(self.gate_proj(features).squeeze(-1)) |
|
|
| gate_info = { |
| 'active': (gate_values.detach() > 0.5).float().sum(-1).mean(), |
| 'gate_mean': gate_values.detach().mean(), |
| 'cm_positive_frac': (anchor_cm_norm > 0).float().mean(), |
| } |
|
|
| return gate_values, gate_info |
|
|
| def forward(self, tri): |
| """Fully compilable forward. Requires precompute() called first.""" |
| return self._compute_gate(self._cached_cm_norm, tri) |
|
|
|
|
| |
| |
| |
|
|
| class GeoResidualBank(nn.Module): |
| """Cross-stream contrastive memory bank (CLIP-style).""" |
| def __init__(self, proj_dim, bank_size=4096, temperature=0.1): |
| super().__init__() |
| self.proj_dim = proj_dim |
| self.bank_size = bank_size |
| self.temperature = temperature |
|
|
| self.register_buffer('queue', torch.randn(bank_size, proj_dim)) |
| self.queue = F.normalize(self.queue, dim=-1) |
| self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) |
|
|
| @torch.no_grad() |
| def enqueue(self, keys): |
| B = keys.shape[0] |
| ptr = int(self.queue_ptr.item()) |
| if ptr + B <= self.bank_size: |
| self.queue[ptr:ptr + B] = keys |
| else: |
| overflow = (ptr + B) - self.bank_size |
| self.queue[ptr:] = keys[:B - overflow] |
| self.queue[:overflow] = keys[B - overflow:] |
| self.queue_ptr[0] = (ptr + B) % self.bank_size |
|
|
| def forward(self, content_proj, geo_proj): |
| q = F.normalize(content_proj, dim=-1) |
| k_pos = F.normalize(geo_proj, dim=-1) |
| k_neg = self.queue.clone().detach() |
|
|
| pos_logits = (q * k_pos).sum(dim=-1, keepdim=True) / self.temperature |
| neg_logits = q @ k_neg.T / self.temperature |
|
|
| logits = torch.cat([pos_logits, neg_logits], dim=1) |
| labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device) |
|
|
| loss = F.cross_entropy(logits, labels) |
|
|
| with torch.no_grad(): |
| acc = (logits.argmax(dim=1) == 0).float().mean() |
|
|
| return loss, acc |
|
|
|
|
| |
| |
| |
|
|
| class FiLMLayer(TorchComponent): |
| """Feature-wise Linear Modulation. Near-identity-initialized. |
| gamma β 1 + 0.01Β·geo_ctx, beta β 0.01Β·geo_ctx at init. |
| Gradient flows through to geo_ctx from step 0. |
| """ |
| def __init__(self, name, feature_dim, context_dim): |
| super().__init__(name) |
| self.to_gamma = nn.Linear(context_dim, feature_dim) |
| self.to_beta = nn.Linear(context_dim, feature_dim) |
| nn.init.normal_(self.to_gamma.weight, std=0.01); nn.init.ones_(self.to_gamma.bias) |
| nn.init.normal_(self.to_beta.weight, std=0.01); nn.init.zeros_(self.to_beta.bias) |
|
|
| def forward(self, x, ctx): |
| return self.to_gamma(ctx) * x + self.to_beta(ctx) |
|
|
|
|
| class CayleyOrthogonal(TorchComponent): |
| """Guaranteed SO(d) rotation via Cayley map. det(Q) = 1 always.""" |
| def __init__(self, name, dim): |
| super().__init__(name) |
| self.dim = dim |
| self.A_upper = nn.Parameter(torch.zeros(dim * (dim - 1) // 2) * 0.01) |
| idx = torch.triu_indices(dim, dim, offset=1) |
| self.register_buffer('_triu_row', idx[0], persistent=False) |
| self.register_buffer('_triu_col', idx[1], persistent=False) |
| self.register_buffer('_eye', torch.eye(dim), persistent=False) |
|
|
| def get_rotation(self): |
| d = self.dim |
| A = torch.zeros(d, d, device=self.A_upper.device, dtype=self.A_upper.dtype) |
| A[self._triu_row, self._triu_col] = self.A_upper |
| A = A - A.T |
| return LA.solve(self._eye + A, self._eye - A) |
|
|
| def forward(self, x): |
| return x @ self.get_rotation().T |
|
|
|
|
| def quaternion_multiply_batched(q1, q2): |
| """Hamilton product on (B, 4, D) tensors. Fully vectorized.""" |
| w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] |
| w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] |
| return torch.stack([ |
| w1*w2 - x1*x2 - y1*y2 - z1*z2, |
| w1*x2 + x1*w2 + y1*z2 - z1*y2, |
| w1*y2 - x1*z2 + y1*w2 + z1*x2, |
| w1*z2 + x1*y2 - y1*x2 + z1*w2, |
| ], dim=1) |
|
|
|
|
| class QuaternionCompose(TorchComponent): |
| """Four-arm Hamilton product composition. Proven in GeoQuat head.""" |
| def __init__(self, name, input_dim, quat_dim=64): |
| super().__init__(name) |
| self.quat_dim = quat_dim |
| self.proj_w = nn.Linear(input_dim, quat_dim) |
| self.proj_i = nn.Linear(input_dim, quat_dim) |
| self.proj_j = nn.Linear(input_dim, quat_dim) |
| self.proj_k = nn.Linear(input_dim, quat_dim) |
| self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1) |
|
|
| @property |
| def output_dim(self): |
| return self.quat_dim * 4 |
|
|
| def forward(self, arm_w, arm_i, arm_j, arm_k): |
| shape = arm_w.shape[:-1] |
| D = arm_w.shape[-1] |
| flat = arm_w.dim() > 2 |
| if flat: |
| arm_w = arm_w.reshape(-1, D); arm_i = arm_i.reshape(-1, D) |
| arm_j = arm_j.reshape(-1, D); arm_k = arm_k.reshape(-1, D) |
| q = torch.stack([self.proj_w(arm_w), self.proj_i(arm_i), |
| self.proj_j(arm_j), self.proj_k(arm_k)], dim=1) |
| q = q / (q.norm(dim=1, keepdim=True) + 1e-8) |
| r = self.rotation.expand(q.shape[0], -1, -1) |
| r = r / (r.norm(dim=1, keepdim=True) + 1e-8) |
| composed = quaternion_multiply_batched(r, q) |
| composed = composed.reshape(q.shape[0], -1) |
| if flat: |
| composed = composed.reshape(*shape, -1) |
| return composed |
|
|
|
|
| |
| |
| |
|
|
| class ManifoldProjection(TorchComponent): |
| """Input stage: project transformer hidden states to S^(d-1).""" |
| def __init__(self, name, d_model, manifold_dim): |
| super().__init__(name) |
| self.proj = nn.Linear(d_model, manifold_dim) |
| self.norm = nn.LayerNorm(manifold_dim) |
|
|
| def forward(self, hidden_states): |
| h = self.norm(self.proj(hidden_states)) |
| return F.normalize(h, dim=-1) |
|
|
|
|
| class PositionGeometricContext(TorchComponent): |
| """Curation stage: 5-stream fusion β FiLM context. |
| |
| Five streams: |
| anchor: cos_to_anchors + assignment + triangulation β WHERE on the manifold |
| structural: patchwork + embedding β WHAT the local geometry looks like |
| history: geo_residual from previous layers β WHAT prior layers observed |
| quality: CM gate values per anchor β HOW TRUSTWORTHY is this observation |
| flow: FlowEnsemble predictions β WHAT other mathematical lenses see |
| |
| The flow stream starts at zero (zero-init) and learns to contribute. |
| Without flows attached, the 5th stream is zeros β equivalent to the |
| original 4-stream architecture. |
| """ |
| def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim): |
| super().__init__(name) |
| self.context_dim = context_dim |
| self.pw_dim = pw_dim |
| self.manifold_dim = manifold_dim |
|
|
| |
| self.anchor_mlp = nn.Sequential( |
| nn.Linear(n_anchors * 3, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
| |
| self.struct_mlp = nn.Sequential( |
| nn.Linear(pw_dim + manifold_dim, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
| |
| self.history_mlp = nn.Sequential( |
| nn.Linear(pw_dim, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
| |
| self.quality_mlp = nn.Sequential( |
| nn.Linear(n_anchors, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
| |
| |
| self.flow_mlp = nn.Sequential( |
| nn.Linear(n_anchors, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
| nn.init.normal_(self.flow_mlp[0].weight, std=0.01) |
| nn.init.zeros_(self.flow_mlp[0].bias) |
|
|
| |
| self.fuse = nn.Sequential( |
| nn.Linear(context_dim * 5, context_dim), nn.GELU(), nn.LayerNorm(context_dim)) |
|
|
| def forward(self, obs_dict, gate_values=None, geo_residual=None, flow_output=None): |
| """ |
| Args: |
| obs_dict: from decomposed association + gated curation |
| gate_values: (N, A) CM gate values per anchor, or None |
| geo_residual: (N, pw_dim) accumulated context, or None for first layer |
| flow_output: (N, manifold_dim) flow ensemble prediction, or None |
| Returns: |
| (N, context_dim) geometric context for FiLM |
| """ |
| anchor_feats = torch.cat([ |
| obs_dict['cos_to_anchors'], |
| obs_dict['assignment'], |
| obs_dict['triangulation'], |
| ], dim=-1) |
| struct_feats = torch.cat([ |
| obs_dict['patchwork'], |
| obs_dict['embedding'], |
| ], dim=-1) |
|
|
| a = self.anchor_mlp(anchor_feats) |
| s = self.struct_mlp(struct_feats) |
| h = self.history_mlp(geo_residual) if geo_residual is not None else torch.zeros_like(a) |
| q = self.quality_mlp(gate_values) if gate_values is not None else torch.zeros_like(a) |
| f = self.flow_mlp(flow_output) if flow_output is not None else torch.zeros_like(a) |
|
|
| return self.fuse(torch.cat([a, s, h, q, f], dim=-1)) |
|
|
|
|
| class GeometricAttention(TorchComponent): |
| """Attention with FiLM from curated constellation. Stream B.""" |
| def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1): |
| super().__init__(name) |
| self.d_model = d_model |
| self.n_heads = n_heads |
| self.head_dim = d_model // n_heads |
| self.scale = self.head_dim ** -0.5 |
|
|
| self.w_q = nn.Linear(d_model, d_model) |
| self.w_k = nn.Linear(d_model, d_model) |
| self.w_v = nn.Linear(d_model, d_model) |
| self.w_o = nn.Linear(d_model, d_model) |
| self.dropout = nn.Dropout(dropout) |
|
|
| self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim) |
| self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim) |
| self.norm = nn.LayerNorm(d_model) |
|
|
| self.ffn1 = nn.Linear(d_model, d_model * 4) |
| self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim) |
| self.ffn2 = nn.Linear(d_model * 4, d_model) |
| self.ffn_drop = nn.Dropout(dropout) |
| self.ffn_norm = nn.LayerNorm(d_model) |
|
|
| def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None): |
| B, L, D = x.shape |
| H, HD = self.n_heads, self.head_dim |
|
|
| Q = self.film_q(self.w_q(x), geo_ctx) |
| K = self.film_k(self.w_k(x), geo_ctx) |
| V = self.w_v(x) |
|
|
| Q = Q.view(B, L, H, HD).transpose(1, 2) |
| K = K.view(B, L, H, HD).transpose(1, 2) |
| V = V.view(B, L, H, HD).transpose(1, 2) |
|
|
| scores = (Q @ K.transpose(-2, -1)) * self.scale |
| if attn_mask is not None: |
| scores = scores + attn_mask |
| if key_padding_mask is not None: |
| scores = scores.masked_fill( |
| key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')) |
| attn_out = (self.dropout(F.softmax(scores, dim=-1)) @ V) |
| attn_out = attn_out.transpose(1, 2).reshape(B, L, D) |
| x = self.norm(x + self.w_o(attn_out)) |
|
|
| h = F.gelu(self.ffn1(x)) |
| h = self.film_ffn(h, geo_ctx) |
| x = self.ffn_norm(x + self.ffn_drop(self.ffn2(h))) |
| return x |
|
|
|
|
| class ContentAttention(TorchComponent): |
| """Standard self-attention. Stream A. No geometric conditioning.""" |
| def __init__(self, name, d_model, n_heads=8, dropout=0.1): |
| super().__init__(name) |
| self.attn = nn.MultiheadAttention( |
| d_model, n_heads, dropout=dropout, batch_first=True) |
| self.norm = nn.LayerNorm(d_model) |
| self.ffn = nn.Sequential( |
| nn.Linear(d_model, d_model * 4), nn.GELU(), |
| nn.Linear(d_model * 4, d_model), nn.Dropout(dropout)) |
| self.ffn_norm = nn.LayerNorm(d_model) |
|
|
| def forward(self, x, attn_mask=None, key_padding_mask=None): |
| a, _ = self.attn(x, x, x, attn_mask=attn_mask, |
| key_padding_mask=key_padding_mask, |
| need_weights=False) |
| x = self.norm(x + a) |
| x = self.ffn_norm(x + self.ffn(x)) |
| return x |
|
|
|
|
| |
| |
| |
|
|
| class GeometricTransformerLayer(BaseTower): |
| """One layer of the geometric transformer (CM validated + flows). |
| |
| Pipeline per layer: |
| 1. ManifoldProjection: h β emb on S^(d-1) |
| 2. Association: emb β raw triangulation, cos, assignment |
| 3. CMValidatedGate: per-anchor CM validity β gate_values |
| 4. Gated curation: patchwork reads tri * gate_values |
| 4.5 FlowEnsemble (optional): multi-opinion geometric predictions |
| 5. PositionGeometricContext: 5 streams β FiLM context |
| 6. ContentAttention (Stream A): standard MHA |
| 7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx) |
| 8. CayleyOrthogonal: align B β A |
| 9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B |
| 10. Decode + gated residual |
| 11. CM-conditioned geometric residual accumulation |
| |
| Flows are optional, config-driven, and individually replaceable: |
| layer['flows'].attach_flow('alignment') |
| layer['flows'].detach_flow('velocity') |
| """ |
| def __init__(self, name, d_model, n_heads=8, n_anchors=32, |
| manifold_dim=256, n_comp=8, d_comp=32, |
| context_dim=128, quat_dim=64, dropout=0.1, |
| cm_neighbors=3, flow_keys=None, flow_fusion='weighted'): |
| super().__init__(name) |
| self.d_model = d_model |
| self.n_anchors = n_anchors |
| self.manifold_dim = manifold_dim |
|
|
| |
| self.attach('projection', ManifoldProjection( |
| f'{name}_proj', d_model, manifold_dim)) |
|
|
| |
| self.attach('observer', ConstellationObserver( |
| dim=manifold_dim, n_anchors=n_anchors, |
| n_comp=n_comp, d_comp=d_comp)) |
|
|
| |
| self.attach('cm_gate', CMValidatedGate( |
| n_anchors=n_anchors, n_neighbors=cm_neighbors)) |
|
|
| |
| if flow_keys: |
| self.attach('flows', FlowEnsemble( |
| f'{name}_flows', manifold_dim, n_anchors, |
| flow_keys=flow_keys, fusion=flow_fusion)) |
| |
| |
| self.flow_alpha = nn.Parameter(torch.tensor(0.01)) |
|
|
| |
| pw_dim = self['observer'].curation.patchwork.output_dim |
| self.attach('context', PositionGeometricContext( |
| f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim)) |
|
|
| |
| self.attach('content', ContentAttention( |
| f'{name}_content', d_model, n_heads, dropout)) |
|
|
| |
| self.attach('geometric', GeometricAttention( |
| f'{name}_geo', d_model, n_heads, context_dim, dropout)) |
|
|
| |
| self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model)) |
|
|
| |
| self.attach('compose', QuaternionCompose( |
| f'{name}_quat', d_model, quat_dim)) |
|
|
| |
| self.attach('decode', nn.Sequential( |
| nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model))) |
| self.attach('gate', nn.Sequential( |
| nn.Linear(d_model * 2, d_model), nn.Sigmoid())) |
|
|
| |
| self._pw_dim = pw_dim |
| self.attach('geo_proj', nn.Sequential( |
| nn.Linear(pw_dim, pw_dim), nn.LayerNorm(pw_dim))) |
|
|
| def forward(self, x, geo_residual=None, attn_mask=None, key_padding_mask=None): |
| """ |
| Args: |
| x: (B, L, D) input hidden states |
| geo_residual: (B, L, pw_dim) accumulated geometric context, |
| or None for first layer |
| |
| Returns: |
| x_out: (B, L, D) transformed hidden states |
| geo_residual_out: (B, L, pw_dim) updated geometric residual |
| geo_state: dict with full geometric state + CM diagnostics |
| """ |
| B, L, D = x.shape |
|
|
| |
| emb = self['projection'](x) |
| emb_flat = emb.reshape(B * L, -1) |
|
|
| |
| a_out = self['observer'].association(emb_flat) |
|
|
| |
| anchors_n = F.normalize( |
| self['observer'].association.constellation.anchors, dim=-1) |
| |
| |
| gate_values, gate_info = self['cm_gate'](a_out['distances']) |
|
|
| |
| a_out_gated = dict(a_out) |
|
|
| |
| flow_opinion = None |
| if self.has('flows'): |
| flow_opinion = self['flows'](anchors_n, emb_flat, a_out['distances']) |
| |
| |
| |
| alpha = self.flow_alpha.sigmoid() |
| blended_tri = a_out['distances'] + alpha * (flow_opinion - a_out['distances']) |
| a_out_gated['distances_weighted'] = blended_tri * gate_values |
| else: |
| a_out_gated['distances_weighted'] = a_out['distances'] * gate_values |
|
|
| c_out = self['observer'].curation.curate_full(a_out_gated, emb=emb_flat) |
|
|
| |
| obs = { |
| 'embedding': emb_flat, |
| 'triangulation': a_out['distances'], |
| 'cos_to_anchors': a_out['cos_to_anchors'], |
| 'assignment': a_out['assignment'], |
| 'nearest': a_out['nearest'], |
| 'patchwork': c_out['patchwork'], |
| 'bridge': c_out['bridge'], |
| } |
|
|
| |
| geo_res_flat = geo_residual.reshape(B * L, -1) if geo_residual is not None else None |
| geo_ctx_flat = self['context']( |
| obs, gate_values=gate_values, geo_residual=geo_res_flat, |
| flow_output=flow_opinion) |
| geo_ctx = geo_ctx_flat.reshape(B, L, -1) |
|
|
| |
| a_out_stream = self['content']( |
| x, attn_mask=attn_mask, key_padding_mask=key_padding_mask) |
|
|
| |
| b_out = self['geometric']( |
| x, geo_ctx, attn_mask=attn_mask, key_padding_mask=key_padding_mask) |
|
|
| |
| b_aligned = self['rotation'](b_out) |
|
|
| |
| composed = self['compose']( |
| arm_w=a_out_stream, arm_i=b_aligned, |
| arm_j=a_out_stream - b_aligned, arm_k=a_out_stream * b_aligned) |
|
|
| |
| decoded = self['decode'](composed) |
| g = self['gate'](torch.cat([x, decoded], dim=-1)) |
| x_out = g * decoded + (1 - g) * x |
|
|
| |
| pw_validated = c_out['patchwork'].reshape(B, L, -1) |
| cm_quality = gate_values.mean(dim=-1).reshape(B, L, 1) |
| geo_update = self['geo_proj'](pw_validated) |
|
|
| if geo_residual is None: |
| geo_residual_out = cm_quality * geo_update |
| else: |
| geo_residual_out = geo_residual + cm_quality * geo_update |
|
|
| |
| def _unflatten(t): |
| if t is None: |
| return None |
| if t.dim() == 1: |
| return t.reshape(B, L) |
| return t.reshape(B, L, *t.shape[1:]) |
|
|
| geo_state = { |
| 'embedding': emb, |
| 'geo_ctx': geo_ctx, |
| 'triangulation': _unflatten(a_out['distances']), |
| 'cos_to_anchors': _unflatten(a_out['cos_to_anchors']), |
| 'assignment': _unflatten(a_out['assignment']), |
| 'nearest': _unflatten(a_out['nearest']), |
| 'patchwork': _unflatten(c_out['patchwork']), |
| 'bridge': _unflatten(c_out['bridge']), |
| 'gate_values': _unflatten(gate_values), |
| 'gate_info': gate_info, |
| 'cm_quality': cm_quality, |
| 'content': a_out_stream, |
| 'geometric': b_out, |
| 'composed': composed, |
| 'geo_residual': geo_residual_out, |
| 'flow_opinion': _unflatten(flow_opinion) if flow_opinion is not None else None, |
| } |
|
|
| return x_out, geo_residual_out, geo_state |
|
|
|
|
| |
| |
| |
|
|
| class GeometricTransformer(BaseTower): |
| """Geometric Transformer β CM-validated dual-stream with optional flows. |
| |
| Stack of GeometricTransformerLayers with: |
| - CM-gated observation at every layer |
| - Optional FlowEnsemble at every layer (config-driven) |
| - Cross-layer Cayley rotation on hidden states |
| - Built-in geometric regularization via geometric_losses() |
| """ |
| def __init__(self, name, d_model=512, n_heads=8, n_layers=4, |
| n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32, |
| context_dim=128, quat_dim=64, dropout=0.1, |
| cross_layer_rotation=True, cm_neighbors=3, |
| nce_bank_size=4096, nce_temperature=0.1, |
| vocab_size=None, max_seq_len=2048, |
| flow_keys=None, flow_fusion='weighted'): |
| super().__init__(name) |
| self.d_model = d_model |
| self.n_layers = n_layers |
| self.n_anchors = n_anchors |
| self._pw_dim = n_comp * d_comp |
|
|
| if vocab_size is not None: |
| self.attach('embed', nn.Embedding(vocab_size, d_model)) |
| self.attach('pos_embed', nn.Embedding(max_seq_len, d_model)) |
| self.attach('head', nn.Linear(d_model, vocab_size, bias=False)) |
|
|
| for i in range(n_layers): |
| self.attach(f'layer_{i}', GeometricTransformerLayer( |
| f'{name}_L{i}', d_model, n_heads, n_anchors, |
| manifold_dim, n_comp, d_comp, context_dim, quat_dim, |
| dropout, cm_neighbors, |
| flow_keys=flow_keys, flow_fusion=flow_fusion)) |
|
|
| if cross_layer_rotation and n_layers > 1: |
| for i in range(n_layers - 1): |
| self.attach(f'cross_rot_{i}', CayleyOrthogonal( |
| f'{name}_xrot_{i}', d_model)) |
|
|
| self.attach('final_norm', nn.LayerNorm(d_model)) |
|
|
| |
| if nce_bank_size > 0: |
| nce_proj_dim = 128 |
| self.attach('nce_content_proj', nn.Sequential( |
| nn.Linear(d_model, nce_proj_dim), |
| nn.GELU(), |
| nn.Linear(nce_proj_dim, nce_proj_dim), |
| )) |
| self.attach('nce_geo_proj', nn.Sequential( |
| nn.Linear(self._pw_dim, nce_proj_dim), |
| nn.GELU(), |
| nn.Linear(nce_proj_dim, nce_proj_dim), |
| )) |
| self.attach('nce_bank', GeoResidualBank( |
| nce_proj_dim, bank_size=nce_bank_size, |
| temperature=nce_temperature)) |
|
|
| self._config = dict( |
| d_model=d_model, n_heads=n_heads, n_layers=n_layers, |
| n_anchors=n_anchors, manifold_dim=manifold_dim, |
| n_comp=n_comp, d_comp=d_comp, context_dim=context_dim, |
| quat_dim=quat_dim, dropout=dropout, |
| cross_layer_rotation=cross_layer_rotation, |
| cm_neighbors=cm_neighbors, vocab_size=vocab_size, |
| nce_bank_size=nce_bank_size, nce_temperature=nce_temperature, |
| flow_keys=flow_keys, flow_fusion=flow_fusion, |
| ) |
|
|
| @property |
| def config(self): |
| return self._config.copy() |
|
|
| def invalidate_caches(self): |
| """Invalidate all CM gate caches. Call after optimizer.step().""" |
| for i in range(self.n_layers): |
| self[f'layer_{i}']['cm_gate'].invalidate_cache() |
|
|
| @torch.compiler.disable |
| def precompute_cm_gates(self): |
| """Precompute CM gate anchor quality for all layers. |
| |
| Must be called BEFORE the compiled forward pass. CUDA graph |
| capture cannot contain module attribute mutations (precompute |
| writes to self._cached_cm_norm). This runs outside the graph. |
| |
| Idempotent: skips layers with warm caches. |
| """ |
| for i in range(self.n_layers): |
| layer = self[f'layer_{i}'] |
| anchors_n = F.normalize( |
| layer['observer'].association.constellation.anchors, dim=-1) |
| layer['cm_gate'].precompute(anchors_n.detach()) |
|
|
| def geometric_losses(self, cv_target=0.215, cv_weight=0.1, spread_weight=0.01): |
| """Compute geometric regularization from current anchor geometry.""" |
| total_cv = torch.tensor(0.0) |
| total_spread = torch.tensor(0.0) |
| n = 0 |
|
|
| for i in range(self.n_layers): |
| layer = self[f'layer_{i}'] |
| anchors = layer['observer'].association.constellation.anchors |
| anchors_n = F.normalize(anchors, dim=-1) |
| A = anchors_n.shape[0] |
|
|
| if n == 0: |
| total_cv = total_cv.to(anchors.device) |
| total_spread = total_spread.to(anchors.device) |
|
|
| cos = anchors_n @ anchors_n.T |
| idx = torch.triu_indices(A, A, offset=1, device=cos.device) |
| pairwise_dist = 1.0 - cos[idx[0], idx[1]] |
| cv = pairwise_dist.std() / (pairwise_dist.mean() + 1e-8) |
| total_cv = total_cv + (cv - cv_target).pow(2) |
|
|
| mask = ~torch.eye(A, dtype=torch.bool, device=cos.device) |
| total_spread = total_spread + F.relu(cos[mask]).mean() |
|
|
| n += 1 |
|
|
| losses = {} |
| if n > 0: |
| losses['cv'] = cv_weight * total_cv / n |
| losses['spread'] = spread_weight * total_spread / n |
| losses['geo_total'] = losses['cv'] + losses['spread'] |
| return losses |
|
|
| def infonce_loss(self, cls_index=0): |
| """Cross-stream contrastive: content queries against decoupled geometry.""" |
| if not self.has('nce_bank'): |
| return {} |
|
|
| hidden = getattr(self, '_last_hidden', None) |
| geo_residual = getattr(self, '_last_geo_residual', None) |
| if hidden is None or geo_residual is None: |
| return {} |
|
|
| content_cls = self['nce_content_proj'](hidden[:, cls_index]) |
| geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach()) |
|
|
| loss, acc = self['nce_bank'](content_cls, geo_cls) |
| return {'nce': loss, 'nce_acc': acc} |
|
|
| @torch.no_grad() |
| def update_nce_bank(self, cls_index=0): |
| """Enqueue projected geo keys into bank. Call AFTER backward.""" |
| if not self.has('nce_bank') or not self.has('nce_geo_proj'): |
| return |
|
|
| geo_residual = getattr(self, '_last_geo_residual', None) |
| if geo_residual is None: |
| return |
|
|
| geo_cls = self['nce_geo_proj'](geo_residual[:, cls_index].detach()) |
| self['nce_bank'].enqueue(F.normalize(geo_cls, dim=-1)) |
|
|
| def anchor_diagnostics(self): |
| """Per-layer anchor health diagnostics.""" |
| diag = {} |
| for i in range(self.n_layers): |
| layer = self[f'layer_{i}'] |
| anchors = layer['observer'].association.constellation.anchors |
| anchors_n = F.normalize(anchors.detach(), dim=-1) |
| A = anchors_n.shape[0] |
|
|
| cos = anchors_n @ anchors_n.T |
| idx = torch.triu_indices(A, A, offset=1, device=cos.device) |
| pairwise = 1.0 - cos[idx[0], idx[1]] |
| cv = (pairwise.std() / (pairwise.mean() + 1e-8)).item() |
|
|
| with torch.no_grad(): |
| anchor_cm, _ = anchor_neighborhood_cm( |
| anchors_n, layer['cm_gate'].n_neighbors) |
|
|
| diag[f'layer_{i}'] = { |
| 'anchor_cv': cv, |
| 'mean_pairwise_dist': pairwise.mean().item(), |
| 'min_pairwise_dist': pairwise.min().item(), |
| 'cm_positive_frac': (anchor_cm > 0).float().mean().item(), |
| 'cm_mean': anchor_cm.mean().item(), |
| 'cm_std': anchor_cm.std().item(), |
| } |
| return diag |
|
|
| def param_report(self): |
| total = 0 |
| name = getattr(self, '_tower_name', self.__class__.__name__) |
| print(f"\n {name} β parameter report (CM-validated + flows)") |
| print(f" {'Component':<35s} {'Params':>12s}") |
| print(f" {'β'*35} {'β'*12}") |
| for cname, module in self.named_children(): |
| n = sum(p.numel() for p in module.parameters()) |
| total += n |
| print(f" {cname:<35s} {n:>12,}") |
| print(f" {'β'*35} {'β'*12}") |
| print(f" {'TOTAL':<35s} {total:>12,}") |
| return total |
|
|
|
|
| def forward(self, x, attn_mask=None, key_padding_mask=None, |
| return_geo_state=False): |
| if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64): |
| pos = torch.arange(x.shape[1], device=x.device) |
| x = self['embed'](x) + self['pos_embed'](pos) |
|
|
| geo_states = [] |
| has_xrot = self.has('cross_rot_0') |
| geo_residual = None |
|
|
| for i in range(self.n_layers): |
| x, geo_residual, geo_state = self[f'layer_{i}']( |
| x, geo_residual=geo_residual, |
| attn_mask=attn_mask, key_padding_mask=key_padding_mask) |
| if return_geo_state: |
| geo_states.append(geo_state) |
| if has_xrot and i < self.n_layers - 1: |
| x = self[f'cross_rot_{i}'](x) |
|
|
| self._last_geo_residual = geo_residual |
| self._last_hidden = x |
|
|
| x = self['final_norm'](x) |
| if self.has('head'): |
| x = self['head'](x) |
|
|
| return (x, geo_states) if return_geo_state else x |
|
|
| |
|
|
| def _run_view(self, x, attn_mask=None, key_padding_mask=None): |
| """Run one view through the full pipeline. |
| Retains ALL layers' geo_states β every layer needs gradient. |
| """ |
| has_xrot = self.has('cross_rot_0') |
| geo_residual = None |
|
|
| if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64): |
| pos = torch.arange(x.shape[1], device=x.device) |
| x = self['embed'](x) + self['pos_embed'](pos) |
|
|
| geo_states = [] |
| for i in range(self.n_layers): |
| x, geo_residual, geo_state = self[f'layer_{i}']( |
| x, geo_residual=geo_residual, |
| attn_mask=attn_mask, key_padding_mask=key_padding_mask) |
| geo_states.append(geo_state) |
| if has_xrot and i < self.n_layers - 1: |
| x = self[f'cross_rot_{i}'](x) |
|
|
| x = self['final_norm'](x) |
| return x, geo_states |
|
|
| def forward_paired(self, x1, x2, cls_index=0, |
| attn_mask=None, key_padding_mask=None): |
| """Dual-view forward for observer loss training. |
| |
| Observer loss reads FINAL layer's observations (coherent space). |
| Non-final layers get gradient through the geo_residual stream |
| (FiLM β context β history_mlp β geo_residual β earlier layers). |
| All layers' computation graphs are retained by _run_view. |
| """ |
| B = x1.shape[0] |
| x_cat = torch.cat([x1, x2], dim=0) |
| feat_cat, geo_states = self._run_view(x_cat, attn_mask, key_padding_mask) |
|
|
| c = cls_index |
| gs = geo_states[-1] |
|
|
| return { |
| 'embedding': gs['embedding'][:B, c], |
| 'embedding_aug': gs['embedding'][B:, c], |
| 'patchwork1': gs['patchwork'][:B, c], |
| 'patchwork1_aug': gs['patchwork'][B:, c], |
| 'bridge1': gs['bridge'][:B, c], |
| 'bridge2': gs['bridge'][B:, c], |
| 'assign1': gs['assignment'][:B, c], |
| 'assign2': gs['assignment'][B:, c], |
| 'cos1': gs['cos_to_anchors'][:B, c], |
| 'tri1': gs['triangulation'][:B, c], |
| 'tri2': gs['triangulation'][B:, c], |
| 'features1': feat_cat[:B], |
| 'features2': feat_cat[B:], |
| 'gate_values1': gs['gate_values'][:B, c], |
| 'gate_values2': gs['gate_values'][B:, c], |
| 'cm_quality1': gs['cm_quality'][:B], |
| 'cm_quality2': gs['cm_quality'][B:], |
| } |
|
|
| def compute_loss(self, output, targets, cls_index=0, |
| w_ce=1.0, head=None, **loss_kwargs): |
| final_layer = self[f'layer_{self.n_layers - 1}'] |
| anchors = final_layer['observer'].association.constellation.anchors |
|
|
| obs_loss, ld = _geolip_observer_loss( |
| output, anchors=anchors, targets=targets, |
| **loss_kwargs) |
|
|
| if head is not None: |
| feat1 = output['features1'][:, cls_index] |
| feat2 = output['features2'][:, cls_index] |
| logits1 = head(feat1) |
| logits2 = head(feat2) |
| l_ce, acc = _geolip_ce_loss_paired(logits1, logits2, targets) |
| ld['ce'], ld['acc'] = l_ce, acc |
| ld['logits'] = logits1 |
| loss = w_ce * l_ce + obs_loss |
| ld['loss_task'] = l_ce.detach() |
| else: |
| loss = obs_loss |
|
|
| ld['loss_observer'] = obs_loss.detach() |
|
|
| w_spread = loss_kwargs.get('w_spread', 0.01) |
| if self.n_layers > 1 and w_spread > 0: |
| other_spread = torch.tensor(0.0, device=anchors.device) |
| for i in range(self.n_layers - 1): |
| layer = self[f'layer_{i}'] |
| layer_anchors = layer['observer'].association.constellation.anchors |
| other_spread = other_spread + _geolip_spread_loss(layer_anchors) |
| other_spread = w_spread * other_spread / (self.n_layers - 1) |
| loss = loss + other_spread |
| ld['spread_other_layers'] = other_spread.detach() |
|
|
| ld['total'] = loss |
| return loss, ld |
|
|
|
|
| |
| |
| |
|
|
| def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw): |
| """Pre-configured for ESM-2 650M (d=1280).""" |
| return GeometricTransformer(name, d_model=1280, n_heads=16, |
| n_layers=n_layers, n_anchors=32, manifold_dim=256, |
| n_comp=8, d_comp=32, context_dim=128, quat_dim=64, **kw) |
|
|
| def geo_transformer_small(name='geo_small', n_layers=4, **kw): |
| """Small config for prototyping.""" |
| return GeometricTransformer(name, d_model=256, n_heads=8, |
| n_layers=n_layers, n_anchors=16, manifold_dim=128, |
| n_comp=4, d_comp=16, context_dim=64, quat_dim=32, **kw) |
|
|
| def geo_transformer_vision(name='geo_vit', n_layers=4, **kw): |
| """For scatter/SVD vision pipeline (patches as tokens).""" |
| return GeometricTransformer(name, d_model=384, n_heads=8, |
| n_layers=n_layers, n_anchors=32, manifold_dim=128, |
| n_comp=8, d_comp=16, context_dim=64, quat_dim=32, **kw) |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == '__main__': |
| print("Geometric Transformer β CM Validated β Self-Test") |
| print("=" * 60) |
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
| |
| model = geo_transformer_small('test_cm', n_layers=2) |
| if hasattr(model, 'network_to'): |
| model.network_to(device=device, strict=False) |
| else: |
| model = model.to(device) |
| total = model.param_report() |
|
|
| |
| B, L, D = 2, 32, 256 |
| x = torch.randn(B, L, D, device=device) |
| out, geos = model(x, return_geo_state=True) |
|
|
| assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}" |
| assert len(geos) == 2 |
| print(f"\n Input: ({B}, {L}, {D})") |
| print(f" Output: {out.shape}") |
| print(f" Geo states: {len(geos)} layers") |
|
|
| |
| for i, gs in enumerate(geos): |
| gi = gs['gate_info'] |
| cm_q = gs['cm_quality'] |
| gv = gs['gate_values'] |
| print(f"\n Layer {i} CM gate:") |
| print(f" active anchors: {gi['active'].item():.1f} / {model.n_anchors}") |
| print(f" gate mean: {gi['gate_mean'].item():.4f}") |
| print(f" cm_positive_frac: {gi['cm_positive_frac'].item():.3f}") |
| print(f" gate_values: {gv.shape} range=[{gv.min():.3f}, {gv.max():.3f}]") |
| print(f" cm_quality: {cm_q.shape} mean={cm_q.mean():.4f}") |
|
|
| |
| gr0 = geos[0]['geo_residual'] |
| gr1 = geos[1]['geo_residual'] |
| print(f"\n Geo residual stream:") |
| print(f" Layer 0: {gr0.shape} norm={gr0.norm(dim=-1).mean():.4f}") |
| print(f" Layer 1: {gr1.shape} norm={gr1.norm(dim=-1).mean():.4f}") |
|
|
| |
| geo_losses = model.geometric_losses() |
| print(f"\n Geometric regularization:") |
| for k, v in geo_losses.items(): |
| print(f" {k}: {v.item():.6f}") |
|
|
| |
| diag = model.anchor_diagnostics() |
| print(f"\n Anchor diagnostics:") |
| for layer_name, d in diag.items(): |
| print(f" {layer_name}:") |
| for k, v in d.items(): |
| print(f" {k}: {v:.4f}") |
|
|
| |
| print(f"\n Cayley rotations:") |
| for name, module in model.named_modules(): |
| if isinstance(module, CayleyOrthogonal): |
| R = module.get_rotation() |
| I = torch.eye(R.shape[0], device=R.device) |
| print(f" {name}: βRRα΅-Iβ={((R@R.T)-I).norm():.8f} det={torch.det(R):.4f}") |
|
|
| |
| print(f"\n Gradient flow test:") |
| model.zero_grad() |
| x_grad = torch.randn(B, L, D, device=device, requires_grad=True) |
| out_grad = model(x_grad) |
| loss = out_grad.sum() |
| loss.backward() |
|
|
| |
| for i in range(model.n_layers): |
| layer = model[f'layer_{i}'] |
| gate_grads = [p.grad is not None and p.grad.abs().sum() > 0 |
| for p in layer['cm_gate'].parameters()] |
| print(f" layer_{i} cm_gate grad: {'YES' if all(gate_grads) else 'NO'}") |
|
|
| |
| print(f"\n Training step simulation:") |
| optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) |
| optimizer.zero_grad() |
|
|
| x_train = torch.randn(B, L, D, device=device) |
| out_train, states = model(x_train, return_geo_state=True) |
| task_loss = out_train.mean() |
|
|
| geo_losses = model.geometric_losses() |
| total_loss = task_loss + geo_losses.get('geo_total', 0.0) |
| total_loss.backward() |
| optimizer.step() |
| print(f" task_loss: {task_loss.item():.4f}") |
| print(f" cv_loss: {geo_losses['cv'].item():.6f}") |
| print(f" spread_loss:{geo_losses['spread'].item():.6f}") |
| print(f" total: {total_loss.item():.4f}") |
|
|
| |
| print(f"\n Paired forward + observer loss:") |
| model.zero_grad() |
|
|
| x1 = torch.randn(B, L, D, device=device) |
| x2 = x1 + 0.1 * torch.randn_like(x1) |
| targets = torch.randint(0, 10, (B,), device=device) |
|
|
| output = model.forward_paired(x1, x2) |
| print(f" Output keys: {sorted(k for k in output if not k.startswith('geo_'))}") |
| for k in ['embedding', 'patchwork1', 'bridge1', 'assign1', 'tri1']: |
| print(f" {k}: {output[k].shape}") |
|
|
| |
| num_classes = 10 |
| head = nn.Linear(D, num_classes).to(device) |
|
|
| loss, ld = model.compute_loss(output, targets, head=head) |
| print(f"\n Three-domain loss breakdown:") |
| for k in ['loss_observer', 'loss_task', 'ce', 'nce_emb', 'nce_pw', |
| 'bridge', 'assign', 'assign_nce', 'nce_tri', 'attract', |
| 'cv', 'spread']: |
| if k in ld: |
| v = ld[k] |
| v = v.item() if isinstance(v, torch.Tensor) else v |
| print(f" {k:16s} = {v:.4f}") |
| for k in ['nce_emb_acc', 'nce_pw_acc', 'nce_tri_acc', 'bridge_acc', |
| 'assign_nce_acc', 'acc']: |
| if k in ld: |
| v = ld[k] |
| v = v if isinstance(v, float) else v.item() |
| print(f" {k:16s} = {v*100:.1f}%") |
| print(f" {'TOTAL':16s} = {loss.item():.4f}") |
|
|
| |
| loss.backward() |
| alive_base, dead_base = [], [] |
| for n, p in model.named_parameters(): |
| if p.grad is not None and p.grad.norm() > 0: |
| alive_base.append(n) |
| else: |
| dead_base.append(n) |
| print(f"\n Gradient flow: {len(alive_base)} params alive, {len(dead_base)} dead") |
| if dead_base: |
| print(f"\n DEAD parameters (base model, paired+observer):") |
| for n in dead_base: |
| print(f" {n}") |
|
|
| |
| |
| |
|
|
| print(f"\n{'='*60}") |
| print(f" WideRouter Compilation") |
| print(f"{'='*60}") |
|
|
| if _HAS_WIDE_ROUTER: |
| |
| router = WideRouter('test_router', strict=False) |
| router.attach('transformer', model) |
| router.register_tower('transformer') |
| router.network_to(device=device, strict=False) |
|
|
| |
| router.discover_towers() |
| print(f"\n Towers discovered: {router.tower_names}") |
| print(f" Analyzed: {router.objects.get('_analyzed', False)}") |
|
|
| try: |
| compiled_router = router.compile(mode='default') |
| print(f" WideRouter.compile(mode='default'): OK") |
| except Exception as e: |
| print(f" WideRouter.compile: {str(e)[:60]}") |
|
|
| |
| with torch.no_grad(): |
| out_via_router = router['transformer'](x) |
| print(f" Forward via router['transformer']: {out_via_router.shape} OK") |
|
|
| del router |
| else: |
| print(f"\n WideRouter: geofractal not installed") |
|
|
| print(f"\n{'='*60}") |
| print(f" PASSED β CM-validated pipeline operational") |
| print(f"{'='*60}") |
|
|
| |
| |
| |
|
|
| print(f"\n{'='*60}") |
| print(f" Flow Ensemble Integration") |
| print(f"{'='*60}") |
|
|
| del model, optimizer |
| torch.cuda.empty_cache() if device.type == 'cuda' else None |
|
|
| model_f = geo_transformer_small('test_flows', n_layers=2, |
| flow_keys=['quat_lite', 'velocity', 'orbital']) |
| if hasattr(model_f, 'network_to'): |
| model_f.network_to(device=device, strict=False) |
| else: |
| model_f = model_f.to(device) |
|
|
| total_f = model_f.param_report() |
| print(f"\n Total params (with flows): {total_f:,}") |
|
|
| print(f"\n Flow ensemble per layer:") |
| for i in range(model_f.n_layers): |
| layer = model_f[f'layer_{i}'] |
| if layer.has('flows'): |
| flows = layer['flows'] |
| names = flows.active_flow_names |
| params = sum(p.numel() for p in flows.parameters()) |
| print(f" layer_{i}: {names} ({params:,} params)") |
| else: |
| print(f" layer_{i}: no flows attached") |
|
|
| x_f = torch.randn(B, L, D, device=device) |
| out_f, geos_f = model_f(x_f, return_geo_state=True) |
| assert out_f.shape == (B, L, D) |
| print(f"\n Forward with flows: {out_f.shape} OK") |
|
|
| geo_ctx_0 = geos_f[0]['geo_ctx'] |
| print(f" Geo context shape: {geo_ctx_0.shape} norm={geo_ctx_0.norm(dim=-1).mean():.4f}") |
|
|
| print(f"\n Flow gradient test (out.sum().backward()):") |
| model_f.zero_grad() |
| x_fg = torch.randn(B, L, D, device=device, requires_grad=True) |
| out_fg = model_f(x_fg) |
| out_fg.sum().backward() |
|
|
| alive_simple, dead_simple = [], [] |
| for n, p in model_f.named_parameters(): |
| if p.grad is not None and p.grad.abs().sum() > 0: |
| alive_simple.append(n) |
| else: |
| dead_simple.append(n) |
| print(f" {len(alive_simple)} alive, {len(dead_simple)} dead") |
| if dead_simple: |
| print(f"\n DEAD parameters (out.sum):") |
| for n in dead_simple: |
| print(f" {n}") |
|
|
| print(f"\n Paired forward + observer loss (with flows):") |
| model_f.zero_grad() |
| x1_f = torch.randn(B, L, D, device=device) |
| x2_f = x1_f + 0.1 * torch.randn_like(x1_f) |
| targets_f = torch.randint(0, 10, (B,), device=device) |
|
|
| output_f = model_f.forward_paired(x1_f, x2_f) |
| head_f = nn.Linear(D, num_classes).to(device) |
| loss_f, ld_f = model_f.compute_loss(output_f, targets_f, head=head_f) |
| print(f" total loss: {loss_f.item():.4f}") |
| loss_f.backward() |
|
|
| alive_paired, dead_paired = [], [] |
| for n, p in model_f.named_parameters(): |
| if p.grad is not None and p.grad.abs().sum() > 0: |
| alive_paired.append(n) |
| else: |
| dead_paired.append(n) |
| print(f" {len(alive_paired)} alive, {len(dead_paired)} dead") |
| if dead_paired: |
| print(f"\n DEAD parameters (paired+observer):") |
| for n in dead_paired: |
| print(f" {n}") |
|
|
| print(f"\n Runtime flow management:") |
| layer0 = model_f['layer_0'] |
| flows_0 = layer0['flows'] |
| print(f" Before: {flows_0.active_flow_names}") |
|
|
| flows_0.attach_flow('alignment') |
| print(f" +alignment: {flows_0.active_flow_names}") |
|
|
| flows_0.detach_flow('velocity') |
| print(f" -velocity: {flows_0.active_flow_names}") |
|
|
| out_swapped = model_f(x_f) |
| assert out_swapped.shape == (B, L, D) |
| print(f" Forward after swap: {out_swapped.shape} OK") |
|
|
| layer1 = model_f['layer_1'] |
| if layer1.has('flows'): |
| for fn in list(layer1['flows'].active_flow_names): |
| key = fn.replace('flow_', '') |
| layer1['flows'].detach_flow(key) |
| print(f" Layer 1 after clear: {layer1['flows'].active_flow_names}") |
| out_partial = model_f(x_f) |
| assert out_partial.shape == (B, L, D) |
| print(f" Forward (L0 flows, L1 empty): {out_partial.shape} OK") |
|
|
| print(f"\n Backward compatibility (no flows):") |
| model_nf = geo_transformer_small('test_noflows', n_layers=2) |
| if hasattr(model_nf, 'network_to'): |
| model_nf.network_to(device=device, strict=False) |
| else: |
| model_nf = model_nf.to(device) |
| out_nf = model_nf(torch.randn(B, L, D, device=device)) |
| assert out_nf.shape == (B, L, D) |
| print(f" Forward (no flows): {out_nf.shape} OK") |
| for i in range(model_nf.n_layers): |
| assert not model_nf[f'layer_{i}'].has('flows'), f"layer_{i} should not have flows" |
| print(f" No flows attached: OK") |
| del model_nf |
|
|
| print(f"\n{'='*60}") |
| print(f" PASSED β CM-validated pipeline operational") |
| print(f" PASSED β Flow ensemble integration verified") |
| print(f" PASSED β Flow attach/detach verified") |
| print(f" PASSED β Backward compatibility verified") |
| print(f"{'='*60}") |