# AbstractPhil's picture
# Rename model.py to v7/model.py
# 01bcf3d verified
"""
Geometric Transformer — CM-Validated Pipeline
==================================================
Dual-stream transformer with CM-gated constellation observation,
quaternion composition, and per-layer Cayley alignment.
CM-validated pipeline changes:
- CM validity gate between association and curation (AnchorGate)
- 5-stream PositionGeometricContext: anchor + structural + history + quality + FLOW
- CM-conditioned geometric residual accumulation (replaces blind learned gate)
- Built-in geometric regularization (CV target + anchor spread)
- Decomposed observer pipeline: association → CM gate → gated curation
- Optional FlowEnsemble: multi-opinion geometric fusion (quat, velocity, orbital, etc.)
Pipeline per layer:
1. ManifoldProjection: h_i → emb_i on S^(d-1) per position
2. ConstellationAssociation: emb_i → raw triangulation, cos, assignment
3. CMValidatedGate: per-anchor CM validity → gate_values (B*L, A)
4. Gated curation: patchwork reads tri * gate_values (validated only)
4.5 FlowEnsemble (optional): multi-opinion geometric predictions
5. PositionGeometricContext: 5 streams → FiLM context (B, L, context_dim)
6. ContentAttention (Stream A): standard MHA
7. GeometricAttention (Stream B): FiLM(Q,K | geo_ctx), V pure
8. CayleyOrthogonal: align B → A basis
9. QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
10. Decode + gated residual
11. CM-conditioned geometric residual write
Geometric regularization (call model.geometric_losses() during training):
- CV loss: anchor CV → pentachoron band (0.20-0.23)
- Spread loss: prevent anchor collapse (penalize positive cosine)
These maintain the constellation in the regime where CM validation works.
Design principles from Ryan Spearman (ρ=0.309, 76/84 wins):
- FiLM on Q,K ONLY — geometry routes attention, V stays pure
- FiLM on individual arms BEFORE composition, not after
- Quaternion algebra as structural regularizer (non-commutative coupling)
- CayleyOrthogonal guarantees pure rotation (det=1 always)
- Never global average pool β€” per-position geometric context
Author: AbstractPhil + Claude Opus 4.6
License: Apache 2.0
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ═══════════════════════════════════════════════════════════════════════════════
# IMPORTS β€” hard dependencies, no fallback stubs
# ═══════════════════════════════════════════════════════════════════════════════
from geolip_core.core.associate.constellation import (
ConstellationObserver, ConstellationAssociation, ConstellationCuration,
Constellation, init_anchors_repulsion,
)
from geolip_core.core.curate.gate import AnchorGate as _GeolipAnchorGate
from geolip_core.core.curate.flows import FlowEnsemble
import geolip_core.linalg as LA
from geolip_core.pipeline.observer import (
TorchComponent, BaseTower, Input, Curation, Distinction,
)
from geolip_core.core.distinguish.losses import (
observer_loss as _geolip_observer_loss,
ce_loss_paired as _geolip_ce_loss_paired,
spread_loss as _geolip_spread_loss,
)
# Optional: geofractal WideRouter for compilation
try:
from geofractal.router.wide_router import WideRouter
_HAS_WIDE_ROUTER = True
except ImportError:
_HAS_WIDE_ROUTER = False
# ═══════════════════════════════════════════════════════════════════════════════
# CAYLEY-MENGER VALIDITY β€” geometric quality measurement
# ═══════════════════════════════════════════════════════════════════════════════
def pairwise_distances_squared(points):
    """Pairwise squared Euclidean distances within each point set.

    Uses the Gram identity ``|a - b|^2 = |a|^2 + |b|^2 - 2 a·b``.

    Args:
        points: (..., N, D) tensor. Any number of leading batch dimensions
            is accepted (including none); the original (B, N, D) layout
            behaves exactly as before.

    Returns:
        (..., N, N) tensor of squared distances. No clamping is applied, so
        entries for near-coincident points may come out slightly negative
        from floating-point cancellation.
    """
    # matmul / negative axes instead of bmm / fixed axes → rank-agnostic.
    gram = torch.matmul(points, points.transpose(-2, -1))
    sq_norms = gram.diagonal(dim1=-2, dim2=-1)
    return sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram
def cayley_menger_det(points):
    """Sign-corrected Cayley-Menger determinant of a batch of simplices.

    Builds the bordered squared-distance matrix (zero corner, ones along the
    first row/column, pairwise squared distances in the remaining block) and
    multiplies its determinant by ``(-1)^(k+1)`` where ``k = K - 1``.
    The result is the raw sign-corrected determinant; no further
    normalisation factor is applied here.

    Args:
        points: (B, K, D) tensor, K vertices per simplex.

    Returns:
        (B,) tensor; positive for a valid non-degenerate simplex.
    """
    n_batch, n_vertices, _ = points.shape
    size = n_vertices + 1

    bordered = torch.zeros(
        n_batch, size, size, device=points.device, dtype=points.dtype)
    bordered[:, 1:, 1:] = pairwise_distances_squared(points)
    bordered[:, 1:, 0] = 1.0
    bordered[:, 0, 1:] = 1.0

    simplex_dim = n_vertices - 1
    parity = (-1.0) ** (simplex_dim + 1)
    return parity * LA.det(bordered)
def anchor_neighborhood_cm(anchors, n_neighbors=3):
    """Per-anchor Cayley-Menger quality from local neighbourhood geometry.

    Each anchor is joined with its ``n_neighbors`` nearest other anchors to
    form one simplex; the sign-corrected CM determinant of that simplex is
    compressed to ``sign(det) * log(|det| + 1e-12)``.

    NOTE(review): because the log of a magnitude below 1 is negative, a
    valid simplex (det > 0) with |det| < 1 yields a NEGATIVE quality value.
    Callers that test ``quality > 0`` are therefore not testing simplex
    validity — confirm whether that is intended.

    Args:
        anchors: (A, D) anchor positions (callers in this file pass
            L2-normalised anchors).
        n_neighbors: neighbours per simplex; must be smaller than A so an
            anchor is never selected as its own neighbour.

    Returns:
        quality: (A,) signed log-magnitude CM value per anchor.
        nn_idx: (A, n_neighbors) indices of the chosen neighbours.

    Raises:
        ValueError: if ``n_neighbors >= A``. Previously this either failed
            inside ``topk`` or silently picked the anchor itself, producing
            a degenerate simplex.
    """
    A, _ = anchors.shape
    if n_neighbors >= A:
        raise ValueError(
            f"n_neighbors ({n_neighbors}) must be smaller than the number "
            f"of anchors ({A})")

    dists = torch.cdist(anchors.unsqueeze(0), anchors.unsqueeze(0)).squeeze(0)
    # Push self-distances out of reach without in-place mutation (compile-safe).
    self_mask = torch.eye(A, device=anchors.device, dtype=anchors.dtype) * 1e12
    dists = dists + self_mask
    _, nn_idx = dists.topk(n_neighbors, largest=False)  # (A, n_neighbors)

    # Simplex vertices: [anchor, neighbour_1, ..., neighbour_k] — vectorised.
    simplices = torch.cat([
        anchors.unsqueeze(1),   # (A, 1, D)
        anchors[nn_idx],        # (A, n_neighbors, D)
    ], dim=1)                   # (A, K, D)

    dets = cayley_menger_det(simplices)  # (A,)
    sign = dets.sign()
    log_mag = torch.log(dets.abs() + 1e-12)
    return sign * log_mag, nn_idx
# ═══════════════════════════════════════════════════════════════════════════════
# CM VALIDATED GATE β€” efficient anchor gating for transformer scale
# ═══════════════════════════════════════════════════════════════════════════════
class CMValidatedGate(nn.Module):
    """Anchor gate driven by Cayley-Menger neighbourhood quality.

    The per-anchor CM quality is computed once by ``precompute()`` and cached
    until ``invalidate_cache()`` is called. ``forward()`` combines that cached
    value with a per-position cosine feature through a small learned MLP and
    a sigmoid.

    The gate starts OPEN: the last linear layer has bias +2 and tiny weights,
    so the initial gate value is about sigmoid(2) ≈ 0.88. It then learns to
    close on configurations it finds unreliable.

    Gate features per (position, anchor):
        - z-scored anchor CM quality (cached, position-independent)
        - ``1 - tri`` (position-dependent; a cosine similarity when ``tri``
          is a ``1 - cos`` distance)

    Args:
        n_anchors: number of constellation anchors.
        n_neighbors: neighbours used for the CM simplex of each anchor.
    """

    def __init__(self, n_anchors, n_neighbors=3):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_neighbors = n_neighbors

        # Learned gate: [cm_quality, cos_sim] → scalar logit.
        self.gate_proj = nn.Sequential(
            nn.Linear(2, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        )
        # Init OPEN — learn to close. Small random weight (not zero) so
        # gradient still reaches gate_proj[0] from the first step.
        nn.init.normal_(self.gate_proj[2].weight, std=0.01)
        nn.init.constant_(self.gate_proj[2].bias, 2.0)

        # Plain attribute (not a buffer): filled by precompute(), cleared by
        # invalidate_cache().
        self._cached_cm_norm = None

    def invalidate_cache(self):
        """Drop the cached anchor quality (call after optimizer.step())."""
        self._cached_cm_norm = None

    def precompute(self, anchors):
        """Fill the anchor-quality cache; no-op when the cache is warm.

        Runs under ``torch.no_grad()`` and stores a detached, z-scored
        (mean 0, std clamped to >= 1e-8) copy of the per-anchor CM quality.

        Args:
            anchors: (A, D) anchor positions passed to
                ``anchor_neighborhood_cm``.
        """
        if self._cached_cm_norm is not None:
            return
        with torch.no_grad():
            anchor_cm, _ = anchor_neighborhood_cm(anchors, self.n_neighbors)
            cm_std = anchor_cm.std().clamp(min=1e-8)
            self._cached_cm_norm = ((anchor_cm - anchor_cm.mean()) / cm_std).detach()

    def _compute_gate(self, anchor_cm_norm, tri):
        """Pure tensor ops: features → MLP → sigmoid.

        Args:
            anchor_cm_norm: (A,) cached z-scored anchor quality.
            tri: (N, A) per-position distance to each anchor.

        Returns:
            gate_values: (N, A) values in (0, 1).
            gate_info: dict of detached scalar diagnostics.
        """
        N, A = tri.shape
        cos_sim = 1.0 - tri
        features = torch.stack([
            anchor_cm_norm.unsqueeze(0).expand(N, -1),
            cos_sim,
        ], dim=-1)
        gate_values = torch.sigmoid(self.gate_proj(features).squeeze(-1))
        gate_info = {
            'active': (gate_values.detach() > 0.5).float().sum(-1).mean(),
            'gate_mean': gate_values.detach().mean(),
            # NOTE(review): computed on the z-scored cache, so this is the
            # fraction of anchors above the mean, not the fraction with a
            # positive raw CM value.
            'cm_positive_frac': (anchor_cm_norm > 0).float().mean(),
        }
        return gate_values, gate_info

    def forward(self, tri):
        """Gate the (N, A) triangulation; requires a warm cache.

        Raises:
            RuntimeError: if ``precompute()`` has not been called since
                construction or the last ``invalidate_cache()``.
        """
        if self._cached_cm_norm is None:
            raise RuntimeError(
                "CMValidatedGate.forward() called with an empty CM cache; "
                "call precompute(anchors) first (e.g. via "
                "GeometricTransformer.precompute_cm_gates()).")
        return self._compute_gate(self._cached_cm_norm, tri)
# ═══════════════════════════════════════════════════════════════════════════════
# INFONCE MEMORY BANK β€” contrastive pressure on geometric residual
# ═══════════════════════════════════════════════════════════════════════════════
class GeoResidualBank(nn.Module):
    """Cross-stream contrastive memory bank (CLIP-style InfoNCE).

    Holds a fixed-size ring buffer of key vectors used as negatives. The
    buffer starts filled with L2-normalised Gaussian noise and is overwritten
    by ``enqueue()``.

    Args:
        proj_dim: dimensionality of the stored keys.
        bank_size: number of rows in the ring buffer.
        temperature: divisor applied to all similarity logits.
    """

    def __init__(self, proj_dim, bank_size=4096, temperature=0.1):
        super().__init__()
        self.proj_dim = proj_dim
        self.bank_size = bank_size
        self.temperature = temperature
        self.register_buffer('queue', torch.randn(bank_size, proj_dim))
        # Re-assigning a registered buffer name keeps it a buffer.
        self.queue = F.normalize(self.queue, dim=-1)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def enqueue(self, keys):
        """Write ``keys`` into the ring buffer at the current pointer.

        Args:
            keys: (B, proj_dim) keys to store. They are stored as given
                (no normalisation here).

        If ``B`` exceeds ``bank_size`` only the newest ``bank_size`` keys are
        kept, placed exactly where sequential one-by-one insertion would have
        left them. (The previous implementation raised a shape-mismatch
        error in that case.)
        """
        n_keys = keys.shape[0]
        ptr = int(self.queue_ptr.item())

        if n_keys > self.bank_size:
            # Older keys would be overwritten anyway: skip them and advance
            # the write position as if they had been inserted.
            ptr = (ptr + n_keys - self.bank_size) % self.bank_size
            keys = keys[-self.bank_size:]
            n_keys = self.bank_size

        # Fill up to the end of the buffer, then wrap the remainder to the start.
        head = min(n_keys, self.bank_size - ptr)
        self.queue[ptr:ptr + head] = keys[:head]
        if head < n_keys:
            self.queue[:n_keys - head] = keys[head:]
        self.queue_ptr[0] = (ptr + n_keys) % self.bank_size

    def forward(self, content_proj, geo_proj):
        """InfoNCE loss of content queries against geometry keys.

        Args:
            content_proj: (B, proj_dim) queries.
            geo_proj: (B, proj_dim) positive keys, paired row-wise with the
                queries.

        Returns:
            loss: cross-entropy with the positive at class index 0 and every
                bank row as a negative.
            acc: fraction of rows whose positive logit is the argmax.
        """
        q = F.normalize(content_proj, dim=-1)
        k_pos = F.normalize(geo_proj, dim=-1)
        k_neg = self.queue.clone().detach()
        pos_logits = (q * k_pos).sum(dim=-1, keepdim=True) / self.temperature
        neg_logits = q @ k_neg.T / self.temperature
        logits = torch.cat([pos_logits, neg_logits], dim=1)
        labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
        loss = F.cross_entropy(logits, labels)
        with torch.no_grad():
            acc = (logits.argmax(dim=1) == 0).float().mean()
        return loss, acc
# ═══════════════════════════════════════════════════════════════════════════════
# PROVEN COMPONENTS β€” from Ryan Spearman (unchanged, tested)
# ═══════════════════════════════════════════════════════════════════════════════
class FiLMLayer(TorchComponent):
    """Feature-wise Linear Modulation: ``gamma(ctx) * x + beta(ctx)``.

    Initialised close to the identity map: the gamma projection has bias 1
    and the beta projection has bias 0, both with small (std=0.01) weights so
    the context still receives gradient from the first step.

    Args:
        name: component name forwarded to ``TorchComponent``.
        feature_dim: size of the modulated feature axis.
        context_dim: size of the conditioning vector.
    """

    def __init__(self, name, feature_dim, context_dim):
        super().__init__(name)
        self.to_gamma = nn.Linear(context_dim, feature_dim)
        self.to_beta = nn.Linear(context_dim, feature_dim)
        # (projection, bias value): scale starts at 1, shift starts at 0.
        for proj, bias_value in ((self.to_gamma, 1.0), (self.to_beta, 0.0)):
            nn.init.normal_(proj.weight, std=0.01)
            nn.init.constant_(proj.bias, bias_value)

    def forward(self, x, ctx):
        """Scale and shift ``x`` with projections of ``ctx``."""
        scale = self.to_gamma(ctx)
        shift = self.to_beta(ctx)
        return scale * x + shift
class CayleyOrthogonal(TorchComponent):
    """Learned rotation parameterised through the Cayley map.

    The free parameters fill the strict upper triangle of a matrix that is
    then antisymmetrised (``A - A^T``); the rotation is obtained by solving
    ``(I + A) Q = (I - A)``. Parameters start at zero, so the initial
    rotation is the identity.

    Args:
        name: component name forwarded to ``TorchComponent``.
        dim: size of the square rotation matrix.
    """

    def __init__(self, name, dim):
        super().__init__(name)
        self.dim = dim
        n_free = dim * (dim - 1) // 2
        self.A_upper = nn.Parameter(torch.zeros(n_free))
        upper_rows, upper_cols = torch.triu_indices(dim, dim, offset=1)
        # Non-persistent helpers: follow .to(device) but stay out of state_dict.
        self.register_buffer('_triu_row', upper_rows, persistent=False)
        self.register_buffer('_triu_col', upper_cols, persistent=False)
        self.register_buffer('_eye', torch.eye(dim), persistent=False)

    def get_rotation(self):
        """Return the (dim, dim) rotation matrix for the current parameters."""
        params = self.A_upper
        upper = torch.zeros(
            self.dim, self.dim, device=params.device, dtype=params.dtype)
        upper[self._triu_row, self._triu_col] = params
        skew = upper - upper.T
        identity = self._eye
        return LA.solve(identity + skew, identity - skew)

    def forward(self, x):
        """Rotate the last axis of ``x`` (right-multiply by ``Q^T``)."""
        rotation = self.get_rotation()
        return x @ rotation.T
def quaternion_multiply_batched(q1, q2):
    """Hamilton product ``q1 * q2`` with components stacked along dim 1.

    Args:
        q1, q2: tensors whose dim 1 holds the (w, x, y, z) components,
            e.g. (B, 4, D). Remaining dims are multiplied element-wise.

    Returns:
        Tensor with the four product components stacked along dim 1.
    """
    a_w, a_x, a_y, a_z = (q1[:, c] for c in range(4))
    b_w, b_x, b_y, b_z = (q2[:, c] for c in range(4))

    real = a_w*b_w - a_x*b_x - a_y*b_y - a_z*b_z
    imag_i = a_w*b_x + a_x*b_w + a_y*b_z - a_z*b_y
    imag_j = a_w*b_y - a_x*b_z + a_y*b_w + a_z*b_x
    imag_k = a_w*b_z + a_x*b_y - a_y*b_x + a_z*b_w
    return torch.stack([real, imag_i, imag_j, imag_k], dim=1)
class QuaternionCompose(TorchComponent):
    """Compose four input arms into one quaternion-valued feature.

    Each arm is linearly projected to ``quat_dim`` and becomes one quaternion
    component (w, i, j, k). The stacked quaternion is normalised across its
    four components, left-multiplied (Hamilton product) by a learned,
    likewise-normalised quaternion, and flattened.

    Args:
        name: component name forwarded to ``TorchComponent``.
        input_dim: feature size of every arm.
        quat_dim: feature size per quaternion component.
    """

    def __init__(self, name, input_dim, quat_dim=64):
        super().__init__(name)
        self.quat_dim = quat_dim
        self.proj_w = nn.Linear(input_dim, quat_dim)
        self.proj_i = nn.Linear(input_dim, quat_dim)
        self.proj_j = nn.Linear(input_dim, quat_dim)
        self.proj_k = nn.Linear(input_dim, quat_dim)
        self.rotation = nn.Parameter(torch.randn(1, 4, quat_dim) * 0.1)

    @property
    def output_dim(self):
        """Size of the flattened output feature axis (4 components)."""
        return self.quat_dim * 4

    def forward(self, arm_w, arm_i, arm_j, arm_k):
        """Return the composed feature with the leading dims of ``arm_w``.

        Inputs with more than two dims are flattened to (N, input_dim) for
        the computation and restored afterwards.
        """
        lead_shape = arm_w.shape[:-1]
        in_dim = arm_w.shape[-1]
        needs_flatten = arm_w.dim() > 2
        if needs_flatten:
            arm_w, arm_i, arm_j, arm_k = (
                arm.reshape(-1, in_dim) for arm in (arm_w, arm_i, arm_j, arm_k))

        components = [
            self.proj_w(arm_w),
            self.proj_i(arm_i),
            self.proj_j(arm_j),
            self.proj_k(arm_k),
        ]
        quat = torch.stack(components, dim=1)                 # (N, 4, quat_dim)
        quat = quat / (quat.norm(dim=1, keepdim=True) + 1e-8)

        n_rows = quat.shape[0]
        learned = self.rotation.expand(n_rows, -1, -1)
        learned = learned / (learned.norm(dim=1, keepdim=True) + 1e-8)

        out = quaternion_multiply_batched(learned, quat).reshape(n_rows, -1)
        if needs_flatten:
            out = out.reshape(*lead_shape, -1)
        return out
# ═══════════════════════════════════════════════════════════════════════════════
# TRANSFORMER-SPECIFIC COMPONENTS
# ═══════════════════════════════════════════════════════════════════════════════
class ManifoldProjection(TorchComponent):
    """Map hidden states onto the unit sphere S^(d-1).

    Pipeline: linear projection → LayerNorm → L2 normalisation of the last
    axis.

    Args:
        name: component name forwarded to ``TorchComponent``.
        d_model: input feature size.
        manifold_dim: output (sphere-embedding) feature size.
    """

    def __init__(self, name, d_model, manifold_dim):
        super().__init__(name)
        self.proj = nn.Linear(d_model, manifold_dim)
        self.norm = nn.LayerNorm(manifold_dim)

    def forward(self, hidden_states):
        """Return unit-norm embeddings with the same leading dims as the input."""
        projected = self.proj(hidden_states)
        return F.normalize(self.norm(projected), dim=-1)
class PositionGeometricContext(TorchComponent):
    """Fuse five per-position streams into one FiLM context vector.

    Streams (each: Linear → GELU → LayerNorm to ``context_dim``):
        anchor:     cos_to_anchors + assignment + triangulation — WHERE on
                    the manifold
        structural: patchwork + embedding — WHAT the local geometry looks like
        history:    geo_residual from previous layers
        quality:    CM gate values per anchor
        flow:       flow-ensemble opinion in anchor space

    An absent optional stream (history / quality / flow) contributes an
    all-zero vector, so a model without flows sees zeros in the fifth slot.

    NOTE(review): the flow stream's first Linear uses a small (std=0.01)
    weight init, but it is followed by a LayerNorm, so its output is not
    necessarily small in magnitude — confirm the intended "fade-in".

    Args:
        name: component name forwarded to ``TorchComponent``.
        n_anchors: number of constellation anchors.
        pw_dim: patchwork (and geo_residual) feature size.
        manifold_dim: embedding feature size.
        context_dim: output context size.
    """

    def __init__(self, name, n_anchors, pw_dim, manifold_dim, context_dim):
        super().__init__(name)
        self.context_dim = context_dim
        self.pw_dim = pw_dim
        self.manifold_dim = manifold_dim

        def stream(in_features):
            # Shared recipe for every stream encoder and for the fuser.
            return nn.Sequential(
                nn.Linear(in_features, context_dim),
                nn.GELU(),
                nn.LayerNorm(context_dim))

        self.anchor_mlp = stream(n_anchors * 3)            # WHERE
        self.struct_mlp = stream(pw_dim + manifold_dim)    # WHAT (local)
        self.history_mlp = stream(pw_dim)                  # WHAT (prior layers)
        self.quality_mlp = stream(n_anchors)               # HOW TRUSTWORTHY
        self.flow_mlp = stream(n_anchors)                  # FLOW OPINIONS

        # Small init on the flow input projection, nonzero so gradient flows.
        flow_in = self.flow_mlp[0]
        nn.init.normal_(flow_in.weight, std=0.01)
        nn.init.zeros_(flow_in.bias)

        self.fuse = stream(context_dim * 5)

    def forward(self, obs_dict, gate_values=None, geo_residual=None, flow_output=None):
        """Build the (N, context_dim) FiLM context.

        Args:
            obs_dict: mapping with keys 'cos_to_anchors', 'assignment',
                'triangulation', 'patchwork' and 'embedding'.
            gate_values: (N, n_anchors) CM gate values, or None.
            geo_residual: (N, pw_dim) accumulated context, or None.
            flow_output: (N, n_anchors) flow opinion, or None.

        Returns:
            (N, context_dim) fused geometric context.
        """
        where = self.anchor_mlp(torch.cat([
            obs_dict['cos_to_anchors'],
            obs_dict['assignment'],
            obs_dict['triangulation'],
        ], dim=-1))
        what = self.struct_mlp(torch.cat([
            obs_dict['patchwork'],
            obs_dict['embedding'],
        ], dim=-1))

        def encode_optional(encoder, value):
            if value is None:
                return torch.zeros_like(where)
            return encoder(value)

        history = encode_optional(self.history_mlp, geo_residual)
        quality = encode_optional(self.quality_mlp, gate_values)
        flow = encode_optional(self.flow_mlp, flow_output)

        streams = [where, what, history, quality, flow]
        return self.fuse(torch.cat(streams, dim=-1))
class GeometricAttention(TorchComponent):
    """Self-attention block conditioned on geometric context (Stream B).

    FiLM modulation from ``geo_ctx`` is applied to the query and key
    projections and to the FFN hidden activation; the value projection is
    left unconditioned. Both sub-blocks are post-norm residual.

    Args:
        name: component name forwarded to ``TorchComponent``.
        d_model: hidden size (split evenly across ``n_heads``).
        n_heads: number of attention heads.
        context_dim: size of the FiLM conditioning vector.
        dropout: dropout rate on attention weights and FFN output.
    """

    def __init__(self, name, d_model, n_heads=8, context_dim=128, dropout=0.1):
        super().__init__(name)
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.film_q = FiLMLayer(f'{name}_film_q', d_model, context_dim)
        self.film_k = FiLMLayer(f'{name}_film_k', d_model, context_dim)
        self.norm = nn.LayerNorm(d_model)
        self.ffn1 = nn.Linear(d_model, d_model * 4)
        self.film_ffn = FiLMLayer(f'{name}_film_ffn', d_model * 4, context_dim)
        self.ffn2 = nn.Linear(d_model * 4, d_model)
        self.ffn_drop = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, geo_ctx, attn_mask=None, key_padding_mask=None):
        """Run FiLM-conditioned attention + FFN.

        Args:
            x: (B, L, d_model) hidden states.
            geo_ctx: (B, L, context_dim) conditioning.
            attn_mask: optional mask following the ``nn.MultiheadAttention``
                convention used by the content stream: a boolean mask marks
                positions that may NOT be attended (True = masked); a float
                mask is added to the scores. Shape (L, L), anything that
                broadcasts against (B, H, L, L), or (B*H, L, L).
            key_padding_mask: optional boolean (B, L) mask, True = padding.

        Returns:
            (B, L, d_model) updated hidden states.

        Note: a query row whose keys are all masked produces NaN after the
        softmax.
        """
        B, L, D = x.shape
        H, HD = self.n_heads, self.head_dim

        Q = self.film_q(self.w_q(x), geo_ctx)
        K = self.film_k(self.w_k(x), geo_ctx)
        V = self.w_v(x)

        Q = Q.view(B, L, H, HD).transpose(1, 2)
        K = K.view(B, L, H, HD).transpose(1, 2)
        V = V.view(B, L, H, HD).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            if attn_mask.dim() == 3 and attn_mask.shape[0] == B * H:
                # nn.MultiheadAttention's per-head layout → (B, H, L, L).
                attn_mask = attn_mask.reshape(B, H, L, L)
            if attn_mask.dtype == torch.bool:
                # Adding a bool mask would only shift scores by 0/1; mask
                # instead so both streams interpret the same mask identically.
                scores = scores.masked_fill(attn_mask, float('-inf'))
            else:
                scores = scores + attn_mask
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        attn_out = (self.dropout(F.softmax(scores, dim=-1)) @ V)
        attn_out = attn_out.transpose(1, 2).reshape(B, L, D)
        x = self.norm(x + self.w_o(attn_out))

        h = F.gelu(self.ffn1(x))
        h = self.film_ffn(h, geo_ctx)
        x = self.ffn_norm(x + self.ffn_drop(self.ffn2(h)))
        return x
class ContentAttention(TorchComponent):
    """Plain post-norm self-attention block (Stream A).

    ``nn.MultiheadAttention`` followed by a 4x-expansion GELU feed-forward,
    each wrapped in a residual connection and LayerNorm. No geometric
    conditioning is applied.

    Args:
        name: component name forwarded to ``TorchComponent``.
        d_model: hidden size.
        n_heads: number of attention heads.
        dropout: dropout rate inside attention and after the FFN.
    """

    def __init__(self, name, d_model, n_heads=8, dropout=0.1):
        super().__init__(name)
        hidden = d_model * 4
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Return the (B, L, d_model) output of attention + FFN."""
        attended, _ = self.attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        after_attn = self.norm(x + attended)
        return self.ffn_norm(after_attn + self.ffn(after_attn))
# ═══════════════════════════════════════════════════════════════════════════════
# LAYER β€” CM-validated dual-stream with constellation routing + flows
# ═══════════════════════════════════════════════════════════════════════════════
class GeometricTransformerLayer(BaseTower):
    """One layer of the geometric transformer (CM validated + flows).

    Pipeline per layer:
        1.   ManifoldProjection: h → emb on S^(d-1)
        2.   Association: emb → raw triangulation, cos, assignment
        3.   CMValidatedGate: per-anchor CM validity → gate_values
        4.   Gated curation: patchwork reads tri * gate_values
        4.5  FlowEnsemble (optional): anchor-space opinion blended into tri
        5.   PositionGeometricContext: 5 streams → FiLM context
        6.   ContentAttention (Stream A): standard MHA
        7.   GeometricAttention (Stream B): FiLM-conditioned attention
        8.   CayleyOrthogonal: align B → A
        9.   QuaternionCompose: w=A, i=aligned_B, j=A-B, k=A*B
        10.  Decode + gated residual
        11.  CM-conditioned geometric residual accumulation

    The CM gate's cache must be warm before ``forward`` runs; this layer does
    not call ``cm_gate.precompute()`` itself (see
    ``GeometricTransformer.precompute_cm_gates``).

    Flows are only attached when ``flow_keys`` is truthy.

    Args:
        name: tower name; also used as prefix for sub-component names.
        d_model: hidden size.
        n_heads: attention heads for both streams.
        n_anchors: constellation anchors.
        manifold_dim: sphere-embedding size.
        n_comp, d_comp: forwarded to ``ConstellationObserver``.
        context_dim: FiLM context size.
        quat_dim: per-component size in the quaternion composition.
        dropout: dropout rate for both attention streams.
        cm_neighbors: neighbours per CM simplex.
        flow_keys: flow identifiers for ``FlowEnsemble``; falsy disables flows.
        flow_fusion: fusion mode forwarded to ``FlowEnsemble``.
    """

    def __init__(self, name, d_model, n_heads=8, n_anchors=32,
                 manifold_dim=256, n_comp=8, d_comp=32,
                 context_dim=128, quat_dim=64, dropout=0.1,
                 cm_neighbors=3, flow_keys=None, flow_fusion='weighted'):
        super().__init__(name)
        self.d_model = d_model
        self.n_anchors = n_anchors
        self.manifold_dim = manifold_dim

        # 1. Project to manifold
        self.attach('projection', ManifoldProjection(
            f'{name}_proj', d_model, manifold_dim))

        # 2. Constellation observer (association + curation — called decomposed)
        self.attach('observer', ConstellationObserver(
            dim=manifold_dim, n_anchors=n_anchors,
            n_comp=n_comp, d_comp=d_comp))

        # 3. CM validated gate — between association and curation
        self.attach('cm_gate', CMValidatedGate(
            n_anchors=n_anchors, n_neighbors=cm_neighbors))

        # 3.5 Flow ensemble — optional multi-opinion geometric fusion
        if flow_keys:
            self.attach('flows', FlowEnsemble(
                f'{name}_flows', manifold_dim, n_anchors,
                flow_keys=flow_keys, fusion=flow_fusion))
            # Blend weight, stored as a LOGIT because forward() applies a
            # sigmoid. Initialised to logit(0.01) so the blend really starts
            # at 1% flow / 99% raw. (Initialising the raw value to 0.01 gave
            # sigmoid(0.01) ≈ 0.50, i.e. a 50/50 blend from step 0.)
            self.flow_alpha = nn.Parameter(torch.tensor(math.log(0.01 / 0.99)))

        # 4. Fuse observation into FiLM context (5 streams)
        pw_dim = self['observer'].curation.patchwork.output_dim
        self.attach('context', PositionGeometricContext(
            f'{name}_ctx', n_anchors, pw_dim, manifold_dim, context_dim))

        # 5. Stream A: content
        self.attach('content', ContentAttention(
            f'{name}_content', d_model, n_heads, dropout))

        # 6. Stream B: geometric
        self.attach('geometric', GeometricAttention(
            f'{name}_geo', d_model, n_heads, context_dim, dropout))

        # 7. Cayley rotation: align B → A
        self.attach('rotation', CayleyOrthogonal(f'{name}_cayley', d_model))

        # 8. Quaternion composition
        self.attach('compose', QuaternionCompose(
            f'{name}_quat', d_model, quat_dim))

        # 9. Decode + output gate
        self.attach('decode', nn.Sequential(
            nn.Linear(quat_dim * 4, d_model), nn.GELU(), nn.LayerNorm(d_model)))
        self.attach('gate', nn.Sequential(
            nn.Linear(d_model * 2, d_model), nn.Sigmoid()))

        # 10. Geometric residual projection (no learned gate — CM quality decides)
        self._pw_dim = pw_dim
        self.attach('geo_proj', nn.Sequential(
            nn.Linear(pw_dim, pw_dim), nn.LayerNorm(pw_dim)))

    def forward(self, x, geo_residual=None, attn_mask=None, key_padding_mask=None):
        """Run one layer.

        Args:
            x: (B, L, D) input hidden states.
            geo_residual: (B, L, pw_dim) accumulated geometric context,
                or None for the first layer.
            attn_mask, key_padding_mask: forwarded unchanged to both
                attention streams.

        Returns:
            x_out: (B, L, D) transformed hidden states.
            geo_residual_out: (B, L, pw_dim) updated geometric residual.
            geo_state: dict with the layer's geometric state + CM diagnostics.
        """
        B, L, D = x.shape

        # ════ 1. Project to manifold ════
        emb = self['projection'](x)                       # (B, L, manifold_dim)
        emb_flat = emb.reshape(B * L, -1)

        # ════ 2. Association — raw triangulation ════
        a_out = self['observer'].association(emb_flat)

        # ════ 3. CM Gate — validate anchor measurements ════
        # The gate's precompute() must already have run (outside any
        # compiled graph), e.g. via GeometricTransformer.precompute_cm_gates().
        gate_values, gate_info = self['cm_gate'](a_out['distances'])

        # ════ 4. Gated curation — patchwork reads validated triangulation ════
        a_out_gated = dict(a_out)  # shallow copy: leave the raw association intact

        # ════ 4.5 Flow ensemble — anchor-space geometric opinions ════
        flow_opinion = None
        if self.has('flows'):
            # Normalised anchors are only needed by the flow ensemble.
            anchors_n = F.normalize(
                self['observer'].association.constellation.anchors, dim=-1)
            flow_opinion = self['flows'](anchors_n, emb_flat, a_out['distances'])  # [N, A]
            # Blend: raw + alpha * (flow - raw), alpha = sigmoid(flow_alpha).
            # Gradient: observer_loss → patchwork → distances_weighted
            #           → flow_opinion → flows
            alpha = self.flow_alpha.sigmoid()
            blended_tri = a_out['distances'] + alpha * (flow_opinion - a_out['distances'])
            a_out_gated['distances_weighted'] = blended_tri * gate_values
        else:
            a_out_gated['distances_weighted'] = a_out['distances'] * gate_values

        c_out = self['observer'].curation.curate_full(a_out_gated, emb=emb_flat)

        # Observation dict for the context builder (raw, un-gated triangulation).
        obs = {
            'embedding': emb_flat,
            'triangulation': a_out['distances'],
            'cos_to_anchors': a_out['cos_to_anchors'],
            'assignment': a_out['assignment'],
            'nearest': a_out['nearest'],
            'patchwork': c_out['patchwork'],
            'bridge': c_out['bridge'],
        }

        # ════ 5. Build FiLM context — 5 streams ════
        geo_res_flat = geo_residual.reshape(B * L, -1) if geo_residual is not None else None
        geo_ctx_flat = self['context'](
            obs, gate_values=gate_values, geo_residual=geo_res_flat,
            flow_output=flow_opinion)
        geo_ctx = geo_ctx_flat.reshape(B, L, -1)

        # ════ 6. Stream A: content attention ════
        a_out_stream = self['content'](
            x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

        # ════ 7. Stream B: geometric attention ════
        b_out = self['geometric'](
            x, geo_ctx, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

        # ════ 8. Cayley rotation: align B → A ════
        b_aligned = self['rotation'](b_out)

        # ════ 9. Quaternion composition ════
        composed = self['compose'](
            arm_w=a_out_stream, arm_i=b_aligned,
            arm_j=a_out_stream - b_aligned, arm_k=a_out_stream * b_aligned)

        # ════ 10. Decode + gated residual ════
        decoded = self['decode'](composed)
        g = self['gate'](torch.cat([x, decoded], dim=-1))
        x_out = g * decoded + (1 - g) * x

        # ════ 11. CM-conditioned geometric residual accumulation ════
        pw_validated = c_out['patchwork'].reshape(B, L, -1)
        # Mean gate value over anchors acts as a per-position write strength.
        cm_quality = gate_values.mean(dim=-1).reshape(B, L, 1)
        geo_update = self['geo_proj'](pw_validated)
        if geo_residual is None:
            geo_residual_out = cm_quality * geo_update
        else:
            geo_residual_out = geo_residual + cm_quality * geo_update

        # ════ Build geo_state dict ════
        def _unflatten(t):
            # (B*L, ...) → (B, L, ...); None passes through.
            if t is None:
                return None
            if t.dim() == 1:
                return t.reshape(B, L)
            return t.reshape(B, L, *t.shape[1:])

        geo_state = {
            'embedding': emb,
            'geo_ctx': geo_ctx,
            'triangulation': _unflatten(a_out['distances']),
            'cos_to_anchors': _unflatten(a_out['cos_to_anchors']),
            'assignment': _unflatten(a_out['assignment']),
            'nearest': _unflatten(a_out['nearest']),
            'patchwork': _unflatten(c_out['patchwork']),
            'bridge': _unflatten(c_out['bridge']),
            'gate_values': _unflatten(gate_values),
            'gate_info': gate_info,
            'cm_quality': cm_quality,
            'content': a_out_stream,
            'geometric': b_out,
            'composed': composed,
            'geo_residual': geo_residual_out,
            'flow_opinion': _unflatten(flow_opinion) if flow_opinion is not None else None,
        }
        return x_out, geo_residual_out, geo_state
# ═══════════════════════════════════════════════════════════════════════════════
# FULL MODEL β€” stack of layers + geometric regularization
# ═══════════════════════════════════════════════════════════════════════════════
class GeometricTransformer(BaseTower):
"""Geometric Transformer — CM-validated dual-stream with optional flows.
Stack of GeometricTransformerLayers with:
- CM-gated observation at every layer
- Optional FlowEnsemble at every layer (config-driven)
- Cross-layer Cayley rotation on hidden states
- Built-in geometric regularization via geometric_losses()
"""
def __init__(self, name, d_model=512, n_heads=8, n_layers=4,
             n_anchors=32, manifold_dim=256, n_comp=8, d_comp=32,
             context_dim=128, quat_dim=64, dropout=0.1,
             cross_layer_rotation=True, cm_neighbors=3,
             nce_bank_size=4096, nce_temperature=0.1,
             vocab_size=None, max_seq_len=2048,
             flow_keys=None, flow_fusion='weighted'):
    """Assemble the layer stack and optional heads.

    Args:
        name: tower name; prefix for sub-component names.
        d_model, n_heads, n_anchors, manifold_dim, n_comp, d_comp,
        context_dim, quat_dim, dropout, cm_neighbors, flow_keys,
        flow_fusion: forwarded to every ``GeometricTransformerLayer``.
        n_layers: number of stacked layers.
        cross_layer_rotation: attach a ``CayleyOrthogonal`` between
            consecutive layers (only when ``n_layers > 1``).
        nce_bank_size: size of the contrastive bank; ``<= 0`` disables the
            InfoNCE projections and bank entirely.
        nce_temperature: temperature of the contrastive bank.
        vocab_size: when given, attach token embedding, positional
            embedding and an output head of this size.
        max_seq_len: number of rows in the positional embedding table.
    """
    super().__init__(name)
    self.d_model = d_model
    self.n_layers = n_layers
    self.n_anchors = n_anchors
    # NOTE(review): assumes the patchwork output size equals
    # n_comp * d_comp; layers read it from the observer instead — confirm
    # the two always agree.
    self._pw_dim = n_comp * d_comp

    if vocab_size is not None:
        self.attach('embed', nn.Embedding(vocab_size, d_model))
        self.attach('pos_embed', nn.Embedding(max_seq_len, d_model))
        self.attach('head', nn.Linear(d_model, vocab_size, bias=False))

    for i in range(n_layers):
        self.attach(f'layer_{i}', GeometricTransformerLayer(
            f'{name}_L{i}', d_model, n_heads, n_anchors,
            manifold_dim, n_comp, d_comp, context_dim, quat_dim,
            dropout, cm_neighbors,
            flow_keys=flow_keys, flow_fusion=flow_fusion))

    if cross_layer_rotation and n_layers > 1:
        for i in range(n_layers - 1):
            self.attach(f'cross_rot_{i}', CayleyOrthogonal(
                f'{name}_xrot_{i}', d_model))

    self.attach('final_norm', nn.LayerNorm(d_model))

    # Cross-stream contrastive (CLIP-style)
    if nce_bank_size > 0:
        nce_proj_dim = 128
        self.attach('nce_content_proj', nn.Sequential(
            nn.Linear(d_model, nce_proj_dim),
            nn.GELU(),
            nn.Linear(nce_proj_dim, nce_proj_dim),
        ))
        self.attach('nce_geo_proj', nn.Sequential(
            nn.Linear(self._pw_dim, nce_proj_dim),
            nn.GELU(),
            nn.Linear(nce_proj_dim, nce_proj_dim),
        ))
        self.attach('nce_bank', GeoResidualBank(
            nce_proj_dim, bank_size=nce_bank_size,
            temperature=nce_temperature))

    # Constructor arguments, kept so the model can be rebuilt from config.
    # max_seq_len is included: without it a rebuilt model falls back to the
    # default positional table size and no longer matches saved weights.
    self._config = dict(
        d_model=d_model, n_heads=n_heads, n_layers=n_layers,
        n_anchors=n_anchors, manifold_dim=manifold_dim,
        n_comp=n_comp, d_comp=d_comp, context_dim=context_dim,
        quat_dim=quat_dim, dropout=dropout,
        cross_layer_rotation=cross_layer_rotation,
        cm_neighbors=cm_neighbors, vocab_size=vocab_size,
        max_seq_len=max_seq_len,
        nce_bank_size=nce_bank_size, nce_temperature=nce_temperature,
        flow_keys=flow_keys, flow_fusion=flow_fusion,
    )
@property
def config(self):
    """Shallow copy of the stored constructor arguments."""
    return dict(self._config)
def invalidate_caches(self):
    """Clear the CM gate cache of every layer (call after optimizer.step())."""
    layer_names = (f'layer_{idx}' for idx in range(self.n_layers))
    for layer_name in layer_names:
        self[layer_name]['cm_gate'].invalidate_cache()
@torch.compiler.disable
def precompute_cm_gates(self):
    """Warm the CM gate cache of every layer from its current anchors.

    Excluded from compilation via the decorator because each gate's
    ``precompute`` writes a module attribute. Call it before the (compiled)
    forward pass. Layers whose cache is already warm are left untouched by
    the gate itself.
    """
    for idx in range(self.n_layers):
        layer = self[f'layer_{idx}']
        raw_anchors = layer['observer'].association.constellation.anchors
        unit_anchors = F.normalize(raw_anchors, dim=-1).detach()
        layer['cm_gate'].precompute(unit_anchors)
def geometric_losses(self, cv_target=0.215, cv_weight=0.1, spread_weight=0.01):
    """Geometric regularisers computed from the current anchors of all layers.

    Per layer, on L2-normalised anchors:
        - CV term: squared gap between ``cv_target`` and the coefficient of
          variation (std / mean) of the pairwise ``1 - cos`` distances.
        - Spread term: mean of ``relu(cos)`` over all off-diagonal pairs,
          i.e. a penalty on positive cosine similarity.

    Args:
        cv_target: desired coefficient of variation.
        cv_weight: weight of the layer-averaged CV term.
        spread_weight: weight of the layer-averaged spread term.

    Returns:
        dict with 'cv', 'spread' and 'geo_total' (their sum); an empty dict
        when the model has no layers.
    """
    cv_sum = torch.tensor(0.0)
    spread_sum = torch.tensor(0.0)
    layers_seen = 0

    for layer_idx in range(self.n_layers):
        raw = self[f'layer_{layer_idx}']['observer'].association.constellation.anchors
        if layers_seen == 0:
            # Accumulators follow the device of the first layer's anchors.
            cv_sum = cv_sum.to(raw.device)
            spread_sum = spread_sum.to(raw.device)

        unit = F.normalize(raw, dim=-1)
        count = unit.shape[0]
        sim = unit @ unit.T

        upper_rows, upper_cols = torch.triu_indices(
            count, count, offset=1, device=sim.device)
        dist = 1.0 - sim[upper_rows, upper_cols]
        cv = dist.std() / (dist.mean() + 1e-8)
        cv_sum = cv_sum + (cv - cv_target).pow(2)

        off_diagonal = ~torch.eye(count, dtype=torch.bool, device=sim.device)
        spread_sum = spread_sum + F.relu(sim[off_diagonal]).mean()
        layers_seen += 1

    if layers_seen == 0:
        return {}

    cv_term = cv_weight * cv_sum / layers_seen
    spread_term = spread_weight * spread_sum / layers_seen
    return {
        'cv': cv_term,
        'spread': spread_term,
        'geo_total': cv_term + spread_term,
    }
def infonce_loss(self, cls_index=0):
    """Contrastive loss between content and (detached) geometry at one position.

    Uses the hidden state and geometric residual stored by the most recent
    ``forward`` call, both read at sequence index ``cls_index``. The
    geometric residual is detached before projection.

    Returns:
        ``{'nce': loss, 'nce_acc': acc}``, or an empty dict when the model
        has no bank or no forward pass has stored its outputs yet.
    """
    if not self.has('nce_bank'):
        return {}

    last_hidden = getattr(self, '_last_hidden', None)
    last_geo = getattr(self, '_last_geo_residual', None)
    if last_hidden is None or last_geo is None:
        return {}

    query = self['nce_content_proj'](last_hidden[:, cls_index])
    key = self['nce_geo_proj'](last_geo[:, cls_index].detach())
    nce, nce_acc = self['nce_bank'](query, key)
    return {'nce': nce, 'nce_acc': nce_acc}
@torch.no_grad()
def update_nce_bank(self, cls_index=0):
    """Push the latest projected, normalised geometry keys into the bank.

    Reads the geometric residual stored by the most recent ``forward`` at
    sequence index ``cls_index``. Does nothing when the bank or its geometry
    projection is absent, or when no residual has been stored. Intended to
    be called after backward.
    """
    bank_ready = self.has('nce_bank') and self.has('nce_geo_proj')
    if not bank_ready:
        return

    last_geo = getattr(self, '_last_geo_residual', None)
    if last_geo is None:
        return

    keys = self['nce_geo_proj'](last_geo[:, cls_index].detach())
    self['nce_bank'].enqueue(F.normalize(keys, dim=-1))
def anchor_diagnostics(self):
    """Collect per-layer anchor health statistics as plain Python floats.

    For each layer, on detached L2-normalised anchors: the coefficient of
    variation, mean and minimum of the pairwise ``1 - cos`` distances, plus
    the fraction-positive, mean and std of the per-anchor CM quality.

    Returns:
        dict keyed by ``'layer_{i}'``, each value a dict of floats.
    """
    report = {}
    for idx in range(self.n_layers):
        layer = self[f'layer_{idx}']
        raw = layer['observer'].association.constellation.anchors
        unit = F.normalize(raw.detach(), dim=-1)
        count = unit.shape[0]

        sim = unit @ unit.T
        upper_rows, upper_cols = torch.triu_indices(
            count, count, offset=1, device=sim.device)
        dist = 1.0 - sim[upper_rows, upper_cols]
        anchor_cv = (dist.std() / (dist.mean() + 1e-8)).item()

        with torch.no_grad():
            quality, _ = anchor_neighborhood_cm(
                unit, layer['cm_gate'].n_neighbors)

        report[f'layer_{idx}'] = {
            'anchor_cv': anchor_cv,
            'mean_pairwise_dist': dist.mean().item(),
            'min_pairwise_dist': dist.min().item(),
            'cm_positive_frac': (quality > 0).float().mean().item(),
            'cm_mean': quality.mean().item(),
            'cm_std': quality.std().item(),
        }
    return report
def param_report(self):
    """Print a per-component parameter table and return the summed count.

    One row is printed per direct child module (``named_children``), followed
    by a TOTAL row. The title uses ``_tower_name`` when that attribute exists
    and the class name otherwise.

    Returns:
        int: Sum of the per-child parameter counts. Parameters registered
        directly on ``self`` (outside any child) are not included, and a
        parameter shared by several children is counted once per child.
    """
    title = getattr(self, '_tower_name', self.__class__.__name__)
    # The separator characters were mis-encoded in the report; use the real
    # em dash and box-drawing rule so the table renders cleanly.
    rule = f" {'─' * 35} {'─' * 12}"
    print(f"\n {title} — parameter report (CM-validated + flows)")
    print(f" {'Component':<35s} {'Params':>12s}")
    print(rule)
    total = 0
    for cname, module in self.named_children():
        count = sum(p.numel() for p in module.parameters())
        total += count
        print(f" {cname:<35s} {count:>12,}")
    print(rule)
    print(f" {'TOTAL':<35s} {total:>12,}")
    return total
def forward(self, x, attn_mask=None, key_padding_mask=None,
            return_geo_state=False):
    """Run the layer stack and return the final output.

    When an ``embed`` component is attached and ``x`` has an integer dtype,
    ``x`` is treated as token ids and replaced by token embedding plus
    positional embedding. Each layer receives the geometric residual produced
    by the previous layer (``None`` for the first). If ``cross_rot_0`` is
    attached, ``cross_rot_{i}`` is applied after every layer except the last.

    Side effects: the final geometric residual and the pre-norm hidden state
    are cached as ``_last_geo_residual`` / ``_last_hidden`` for
    ``infonce_loss`` and ``update_nce_bank``.

    Args:
        x: Input features, or integer token ids when ``embed`` is attached.
        attn_mask: Passed unchanged to every layer.
        key_padding_mask: Passed unchanged to every layer.
        return_geo_state: Also return the list of per-layer geo states.

    Returns:
        The output after ``final_norm`` (and ``head`` if attached), or the
        tuple ``(output, geo_states)`` when ``return_geo_state`` is true.
    """
    if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
        positions = torch.arange(x.shape[1], device=x.device)
        x = self['embed'](x) + self['pos_embed'](positions)

    rotate_between = self.has('cross_rot_0')
    last_layer = self.n_layers - 1
    collected = []
    geo_residual = None
    for idx in range(self.n_layers):
        x, geo_residual, layer_state = self[f'layer_{idx}'](
            x, geo_residual=geo_residual,
            attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        if return_geo_state:
            collected.append(layer_state)
        # Cross rotations sit between layers only, never after the last one.
        if rotate_between and idx < last_layer:
            x = self[f'cross_rot_{idx}'](x)

    # Cache before the final norm: the NCE helpers read the pre-norm hidden.
    self._last_geo_residual = geo_residual
    self._last_hidden = x

    x = self['final_norm'](x)
    if self.has('head'):
        x = self['head'](x)
    if return_geo_state:
        return x, collected
    return x
# ── Paired forward + observer loss ──────────────────────────────
def _run_view(self, x, attn_mask=None, key_padding_mask=None):
    """Push one batch through every layer and keep each layer's geo state.

    Same layer loop as ``forward``, with three differences: every layer's
    geo state is always collected, no ``head`` is applied, and nothing is
    cached on ``self``.

    Args:
        x: Input features, or integer token ids when ``embed`` is attached.
        attn_mask: Passed unchanged to every layer.
        key_padding_mask: Passed unchanged to every layer.

    Returns:
        ``(features, geo_states)``: the output of ``final_norm`` and the list
        of per-layer geo states, in layer order.
    """
    rotate_between = self.has('cross_rot_0')
    geo_residual = None
    if self.has('embed') and x.dtype in (torch.long, torch.int32, torch.int64):
        positions = torch.arange(x.shape[1], device=x.device)
        x = self['embed'](x) + self['pos_embed'](positions)

    layer_states = []
    final_idx = self.n_layers - 1
    for idx in range(self.n_layers):
        x, geo_residual, state = self[f'layer_{idx}'](
            x, geo_residual=geo_residual,
            attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        layer_states.append(state)
        # Cross rotations sit between layers only, never after the last one.
        if rotate_between and idx < final_idx:
            x = self[f'cross_rot_{idx}'](x)

    return self['final_norm'](x), layer_states
def forward_paired(self, x1, x2, cls_index=0,
                   attn_mask=None, key_padding_mask=None):
    """Run two views as one concatenated batch and split the observations.

    ``x1`` and ``x2`` are concatenated along dim 0 and sent through
    ``_run_view`` once. The observer quantities are read from the FINAL
    layer's geo state only; earlier layers receive gradient through the
    geo_residual stream
    (FiLM → context → history_mlp → geo_residual → earlier layers).

    Args:
        x1: First view.
        x2: Second view; concatenated with ``x1`` along the batch axis.
        cls_index: Sequence position read from the per-position observations.
        attn_mask: Forwarded to ``_run_view``.
        key_padding_mask: Forwarded to ``_run_view``.

    Returns:
        dict with the view-1 / view-2 slices of the final layer's
        ``embedding``, ``patchwork``, ``bridge``, ``assignment``,
        ``triangulation`` and ``gate_values`` at ``cls_index``; the view-1
        slice of ``cos_to_anchors`` (there is no ``cos2`` entry); the full
        feature tensors ``features1`` / ``features2``; and ``cm_quality1`` /
        ``cm_quality2`` split along the batch axis only.
    """
    n_first = x1.shape[0]
    stacked = torch.cat([x1, x2], dim=0)
    features, layer_states = self._run_view(stacked, attn_mask, key_padding_mask)
    final = layer_states[-1]  # final layer: coherent representation space

    def first(key):
        # Rows belonging to x1, at the summary position.
        return final[key][:n_first, cls_index]

    def second(key):
        # Rows belonging to x2, at the summary position.
        return final[key][n_first:, cls_index]

    return {
        'embedding': first('embedding'),
        'embedding_aug': second('embedding'),
        'patchwork1': first('patchwork'),
        'patchwork1_aug': second('patchwork'),
        'bridge1': first('bridge'),
        'bridge2': second('bridge'),
        'assign1': first('assignment'),
        'assign2': second('assignment'),
        'cos1': first('cos_to_anchors'),
        'tri1': first('triangulation'),
        'tri2': second('triangulation'),
        'features1': features[:n_first],
        'features2': features[n_first:],
        'gate_values1': first('gate_values'),
        'gate_values2': second('gate_values'),
        'cm_quality1': final['cm_quality'][:n_first],
        'cm_quality2': final['cm_quality'][n_first:],
    }
def compute_loss(self, output, targets, cls_index=0,
                 w_ce=1.0, head=None, **loss_kwargs):
    """Combine observer loss, optional task CE and extra anchor-spread terms.

    The observer loss is evaluated against the FINAL layer's anchors. When a
    ``head`` is supplied, paired cross-entropy on the features at
    ``cls_index`` is added with weight ``w_ce``. For models with more than
    one layer, the mean spread loss of the non-final layers' anchors is added
    with weight ``w_spread``.

    Args:
        output: Dict as returned by ``forward_paired``.
        targets: Labels passed to the observer loss and the paired CE.
        cls_index: Sequence position of the features fed to ``head``.
        w_ce: Weight on the cross-entropy term.
        head: Optional callable mapping features to logits.
        **loss_kwargs: Forwarded to ``_geolip_observer_loss``. ``w_spread``
            is also read here (default 0.01 — presumably mirrors the
            observer loss default; confirm against that function).

    Returns:
        ``(loss, ld)`` where ``ld`` is the loss dict from the observer loss
        extended with ``ce``, ``acc``, ``logits``, ``loss_task`` (head only),
        ``loss_observer``, ``spread_other_layers`` (multi-layer only) and
        ``total``.
    """
    last_idx = self.n_layers - 1
    anchors = self[f'layer_{last_idx}']['observer'].association.constellation.anchors
    obs_loss, ld = _geolip_observer_loss(
        output, anchors=anchors, targets=targets, **loss_kwargs)

    if head is None:
        loss = obs_loss
    else:
        feat1, feat2 = (output[key][:, cls_index]
                        for key in ('features1', 'features2'))
        logits1 = head(feat1)
        logits2 = head(feat2)
        l_ce, acc = _geolip_ce_loss_paired(logits1, logits2, targets)
        ld['ce'], ld['acc'] = l_ce, acc
        ld['logits'] = logits1
        loss = w_ce * l_ce + obs_loss
        ld['loss_task'] = l_ce.detach()
    ld['loss_observer'] = obs_loss.detach()

    w_spread = loss_kwargs.get('w_spread', 0.01)
    if self.n_layers > 1 and w_spread > 0:
        # The final layer's anchors go through the observer loss above; the
        # remaining layers only get this averaged spread penalty.
        spread_terms = (
            _geolip_spread_loss(
                self[f'layer_{i}']['observer'].association.constellation.anchors)
            for i in range(last_idx))
        spread_sum = sum(spread_terms, torch.tensor(0.0, device=anchors.device))
        other_spread = w_spread * spread_sum / last_idx
        loss = loss + other_spread
        ld['spread_other_layers'] = other_spread.detach()

    ld['total'] = loss
    return loss, ld
# ═══════════════════════════════════════════════════════════════════════════════
# FACTORIES
# ═══════════════════════════════════════════════════════════════════════════════
def geo_transformer_esm2(name='geo_esm2', n_layers=6, **kw):
    """Build a GeometricTransformer preset for ESM-2 650M (d_model=1280).

    Args:
        name: Passed as the first positional argument of the model.
        n_layers: Number of layers.
        **kw: Extra keyword arguments forwarded to ``GeometricTransformer``.
            Repeating a preset key raises ``TypeError``.
    """
    preset = dict(
        d_model=1280, n_heads=16,
        n_layers=n_layers, n_anchors=32, manifold_dim=256,
        n_comp=8, d_comp=32, context_dim=128, quat_dim=64,
    )
    # Two separate unpackings: a key present in both still raises TypeError.
    return GeometricTransformer(name, **preset, **kw)
def geo_transformer_small(name='geo_small', n_layers=4, **kw):
    """Build a small GeometricTransformer preset for prototyping (d_model=256).

    Args:
        name: Passed as the first positional argument of the model.
        n_layers: Number of layers.
        **kw: Extra keyword arguments forwarded to ``GeometricTransformer``.
            Repeating a preset key raises ``TypeError``.
    """
    preset = dict(
        d_model=256, n_heads=8,
        n_layers=n_layers, n_anchors=16, manifold_dim=128,
        n_comp=4, d_comp=16, context_dim=64, quat_dim=32,
    )
    # Two separate unpackings: a key present in both still raises TypeError.
    return GeometricTransformer(name, **preset, **kw)
def geo_transformer_vision(name='geo_vit', n_layers=4, **kw):
    """Build a GeometricTransformer preset for the scatter/SVD vision pipeline.

    Patches are used as tokens (d_model=384).

    Args:
        name: Passed as the first positional argument of the model.
        n_layers: Number of layers.
        **kw: Extra keyword arguments forwarded to ``GeometricTransformer``.
            Repeating a preset key raises ``TypeError``.
    """
    preset = dict(
        d_model=384, n_heads=8,
        n_layers=n_layers, n_anchors=32, manifold_dim=128,
        n_comp=8, d_comp=16, context_dim=64, quat_dim=32,
    )
    # Two separate unpackings: a key present in both still raises TypeError.
    return GeometricTransformer(name, **preset, **kw)
# ═══════════════════════════════════════════════════════════════════════════════
# SELF-TEST
# ═══════════════════════════════════════════════════════════════════════════════
if __name__ == '__main__':
    # Smoke test for the CM-validated pipeline. Each section runs a small
    # model on random input and prints diagnostics; the asserts check output
    # shapes and flow attachment only.
    print("Geometric Transformer — CM Validated — Self-Test")
    print("=" * 60)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Class count shared by the random targets and the linear task heads.
    num_classes = 10

    # ── Build small model ──
    model = geo_transformer_small('test_cm', n_layers=2)
    # Use network_to when the model provides it, plain .to() otherwise.
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)
    model.param_report()

    # ── Forward pass ──
    B, L, D = 2, 32, 256
    x = torch.randn(B, L, D, device=device)
    out, geos = model(x, return_geo_state=True)
    assert out.shape == (B, L, D), f"Expected ({B},{L},{D}), got {out.shape}"
    assert len(geos) == 2
    print(f"\n Input: ({B}, {L}, {D})")
    print(f" Output: {out.shape}")
    print(f" Geo states: {len(geos)} layers")

    # ── Verify CM gate is active ──
    for i, gs in enumerate(geos):
        gi = gs['gate_info']
        cm_q = gs['cm_quality']
        gv = gs['gate_values']
        print(f"\n Layer {i} CM gate:")
        print(f" active anchors: {gi['active'].item():.1f} / {model.n_anchors}")
        print(f" gate mean: {gi['gate_mean'].item():.4f}")
        print(f" cm_positive_frac: {gi['cm_positive_frac'].item():.3f}")
        print(f" gate_values: {gv.shape} range=[{gv.min():.3f}, {gv.max():.3f}]")
        print(f" cm_quality: {cm_q.shape} mean={cm_q.mean():.4f}")

    # ── Verify geo_residual continuity ──
    gr0 = geos[0]['geo_residual']
    gr1 = geos[1]['geo_residual']
    print(f"\n Geo residual stream:")
    print(f" Layer 0: {gr0.shape} norm={gr0.norm(dim=-1).mean():.4f}")
    print(f" Layer 1: {gr1.shape} norm={gr1.norm(dim=-1).mean():.4f}")

    # ── Geometric losses ──
    geo_losses = model.geometric_losses()
    print(f"\n Geometric regularization:")
    for k, v in geo_losses.items():
        print(f" {k}: {v.item():.6f}")

    # ── Anchor diagnostics ──
    diag = model.anchor_diagnostics()
    print(f"\n Anchor diagnostics:")
    for layer_name, d in diag.items():
        print(f" {layer_name}:")
        for k, v in d.items():
            print(f" {k}: {v:.4f}")

    # ── Verify Cayley rotations ──
    # Prints the orthogonality error ‖RRᵀ-I‖ and the determinant of each R.
    print(f"\n Cayley rotations:")
    for name, module in model.named_modules():
        if isinstance(module, CayleyOrthogonal):
            R = module.get_rotation()
            identity = torch.eye(R.shape[0], device=R.device)
            print(f" {name}: ‖RRᵀ-I‖={((R@R.T)-identity).norm():.8f} det={torch.det(R):.4f}")

    # ── Gradient flow through CM gate ──
    print(f"\n Gradient flow test:")
    model.zero_grad()
    x_grad = torch.randn(B, L, D, device=device, requires_grad=True)
    out_grad = model(x_grad)
    loss = out_grad.sum()
    loss.backward()
    # Every cm_gate parameter must carry a non-zero gradient.
    for i in range(model.n_layers):
        layer = model[f'layer_{i}']
        gate_grads = [p.grad is not None and p.grad.abs().sum() > 0
                      for p in layer['cm_gate'].parameters()]
        print(f" layer_{i} cm_gate grad: {'YES' if all(gate_grads) else 'NO'}")

    # ── Training step simulation ──
    print(f"\n Training step simulation:")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    optimizer.zero_grad()
    x_train = torch.randn(B, L, D, device=device)
    out_train, states = model(x_train, return_geo_state=True)
    task_loss = out_train.mean()  # dummy
    geo_losses = model.geometric_losses()
    total_loss = task_loss + geo_losses.get('geo_total', 0.0)
    total_loss.backward()
    optimizer.step()
    print(f" task_loss: {task_loss.item():.4f}")
    print(f" cv_loss: {geo_losses['cv'].item():.6f}")
    print(f" spread_loss:{geo_losses['spread'].item():.6f}")
    print(f" total: {total_loss.item():.4f}")

    # ── Paired forward + observer loss ──
    print(f"\n Paired forward + observer loss:")
    model.zero_grad()
    x1 = torch.randn(B, L, D, device=device)
    x2 = x1 + 0.1 * torch.randn_like(x1)  # view 2 = slight perturbation
    targets = torch.randint(0, num_classes, (B,), device=device)
    output = model.forward_paired(x1, x2)
    print(f" Output keys: {sorted(k for k in output if not k.startswith('geo_'))}")
    for k in ['embedding', 'patchwork1', 'bridge1', 'assign1', 'tri1']:
        print(f" {k}: {output[k].shape}")

    # Task head for CE
    head = nn.Linear(D, num_classes).to(device)
    loss, ld = model.compute_loss(output, targets, head=head)
    print(f"\n Three-domain loss breakdown:")
    for k in ['loss_observer', 'loss_task', 'ce', 'nce_emb', 'nce_pw',
              'bridge', 'assign', 'assign_nce', 'nce_tri', 'attract',
              'cv', 'spread']:
        if k in ld:
            v = ld[k]
            v = v.item() if isinstance(v, torch.Tensor) else v
            print(f" {k:16s} = {v:.4f}")
    for k in ['nce_emb_acc', 'nce_pw_acc', 'nce_tri_acc', 'bridge_acc',
              'assign_nce_acc', 'acc']:
        if k in ld:
            v = ld[k]
            v = v if isinstance(v, float) else v.item()
            print(f" {k:16s} = {v*100:.1f}%")
    print(f" {'TOTAL':16s} = {loss.item():.4f}")

    # Verify backward through observer loss
    loss.backward()
    alive_base, dead_base = [], []
    for n, p in model.named_parameters():
        if p.grad is not None and p.grad.norm() > 0:
            alive_base.append(n)
        else:
            dead_base.append(n)
    print(f"\n Gradient flow: {len(alive_base)} params alive, {len(dead_base)} dead")
    if dead_base:
        print(f"\n DEAD parameters (base model, paired+observer):")
        for n in dead_base:
            print(f" {n}")

    # ══════════════════════════════════════════════════════════════
    # WIDE ROUTER COMPILATION
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*60}")
    print(f" WideRouter Compilation")
    print(f"{'='*60}")
    if _HAS_WIDE_ROUTER:
        # Wrap transformer in WideRouter (same pattern as GeoViTClassifier)
        router = WideRouter('test_router', strict=False)
        router.attach('transformer', model)
        router.register_tower('transformer')
        router.network_to(device=device, strict=False)
        # Discover towers and compile
        router.discover_towers()
        print(f"\n Towers discovered: {router.tower_names}")
        print(f" Analyzed: {router.objects.get('_analyzed', False)}")
        # Compilation is best-effort here: a failure is reported, not fatal.
        try:
            router.compile(mode='default')
            print(f" WideRouter.compile(mode='default'): OK")
        except Exception as e:
            print(f" WideRouter.compile: {str(e)[:60]}")
        # Forward through the registered tower directly
        with torch.no_grad():
            out_via_router = router['transformer'](x)
        print(f" Forward via router['transformer']: {out_via_router.shape} OK")
        del router
    else:
        print(f"\n WideRouter: geofractal not installed")
    print(f"\n{'='*60}")
    print(f" PASSED — CM-validated pipeline operational")
    print(f"{'='*60}")

    # ══════════════════════════════════════════════════════════════
    # FLOW ENSEMBLE INTEGRATION TESTS
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*60}")
    print(f" Flow Ensemble Integration")
    print(f"{'='*60}")
    # Drop the first model before building the flow-enabled one.
    del model, optimizer
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    model_f = geo_transformer_small('test_flows', n_layers=2,
                                    flow_keys=['quat_lite', 'velocity', 'orbital'])
    if hasattr(model_f, 'network_to'):
        model_f.network_to(device=device, strict=False)
    else:
        model_f = model_f.to(device)
    total_f = model_f.param_report()
    print(f"\n Total params (with flows): {total_f:,}")
    print(f"\n Flow ensemble per layer:")
    for i in range(model_f.n_layers):
        layer = model_f[f'layer_{i}']
        if layer.has('flows'):
            flows = layer['flows']
            names = flows.active_flow_names
            params = sum(p.numel() for p in flows.parameters())
            print(f" layer_{i}: {names} ({params:,} params)")
        else:
            print(f" layer_{i}: no flows attached")

    x_f = torch.randn(B, L, D, device=device)
    out_f, geos_f = model_f(x_f, return_geo_state=True)
    assert out_f.shape == (B, L, D)
    print(f"\n Forward with flows: {out_f.shape} OK")
    geo_ctx_0 = geos_f[0]['geo_ctx']
    print(f" Geo context shape: {geo_ctx_0.shape} norm={geo_ctx_0.norm(dim=-1).mean():.4f}")

    print(f"\n Flow gradient test (out.sum().backward()):")
    model_f.zero_grad()
    x_fg = torch.randn(B, L, D, device=device, requires_grad=True)
    out_fg = model_f(x_fg)
    out_fg.sum().backward()
    alive_simple, dead_simple = [], []
    for n, p in model_f.named_parameters():
        if p.grad is not None and p.grad.abs().sum() > 0:
            alive_simple.append(n)
        else:
            dead_simple.append(n)
    print(f" {len(alive_simple)} alive, {len(dead_simple)} dead")
    if dead_simple:
        print(f"\n DEAD parameters (out.sum):")
        for n in dead_simple:
            print(f" {n}")

    print(f"\n Paired forward + observer loss (with flows):")
    model_f.zero_grad()
    x1_f = torch.randn(B, L, D, device=device)
    x2_f = x1_f + 0.1 * torch.randn_like(x1_f)
    targets_f = torch.randint(0, num_classes, (B,), device=device)
    output_f = model_f.forward_paired(x1_f, x2_f)
    head_f = nn.Linear(D, num_classes).to(device)
    loss_f, ld_f = model_f.compute_loss(output_f, targets_f, head=head_f)
    print(f" total loss: {loss_f.item():.4f}")
    loss_f.backward()
    alive_paired, dead_paired = [], []
    for n, p in model_f.named_parameters():
        if p.grad is not None and p.grad.abs().sum() > 0:
            alive_paired.append(n)
        else:
            dead_paired.append(n)
    print(f" {len(alive_paired)} alive, {len(dead_paired)} dead")
    if dead_paired:
        print(f"\n DEAD parameters (paired+observer):")
        for n in dead_paired:
            print(f" {n}")

    # Attach and detach flows on a live model, then check forward still runs.
    print(f"\n Runtime flow management:")
    layer0 = model_f['layer_0']
    flows_0 = layer0['flows']
    print(f" Before: {flows_0.active_flow_names}")
    flows_0.attach_flow('alignment')
    print(f" +alignment: {flows_0.active_flow_names}")
    flows_0.detach_flow('velocity')
    print(f" -velocity: {flows_0.active_flow_names}")
    out_swapped = model_f(x_f)
    assert out_swapped.shape == (B, L, D)
    print(f" Forward after swap: {out_swapped.shape} OK")
    layer1 = model_f['layer_1']
    if layer1.has('flows'):
        # Iterate over a copy: detaching changes the active list.
        for fn in list(layer1['flows'].active_flow_names):
            key = fn.replace('flow_', '')
            layer1['flows'].detach_flow(key)
        print(f" Layer 1 after clear: {layer1['flows'].active_flow_names}")
    out_partial = model_f(x_f)
    assert out_partial.shape == (B, L, D)
    print(f" Forward (L0 flows, L1 empty): {out_partial.shape} OK")

    print(f"\n Backward compatibility (no flows):")
    model_nf = geo_transformer_small('test_noflows', n_layers=2)
    if hasattr(model_nf, 'network_to'):
        model_nf.network_to(device=device, strict=False)
    else:
        model_nf = model_nf.to(device)
    out_nf = model_nf(torch.randn(B, L, D, device=device))
    assert out_nf.shape == (B, L, D)
    print(f" Forward (no flows): {out_nf.shape} OK")
    for i in range(model_nf.n_layers):
        assert not model_nf[f'layer_{i}'].has('flows'), f"layer_{i} should not have flows"
    print(f" No flows attached: OK")
    del model_nf

    print(f"\n{'='*60}")
    print(f" PASSED — CM-validated pipeline operational")
    print(f" PASSED — Flow ensemble integration verified")
    print(f" PASSED — Flow attach/detach verified")
    print(f" PASSED — Backward compatibility verified")
    print(f"{'='*60}")