| """ |
geolip.flows — Multi-flow ensemble for constellation geometry.

Each flow predicts the same geometric output using a different mathematical
formulation. The ensemble fuses the predictions using learned per-flow
confidence.

Flows:
    QuaternionFlow      — Full MHA quaternion rotation (existing, heavyweight)
    QuaternionLiteFlow  — Lightweight quaternion rotation predicted from the
                          anchor centroid (no multi-head attention)
    VelocityFlow        — Angular velocity dq/dt on the tangent bundle
    MagnitudeFlow       — Flow magnitude via Gram eigenvalue spectrum
    OrbitalFlow         — Omega-based orbital resonance using FL eigh
    AlignmentFlow       — SVD alignment via Procrustes rotation

Architecture:
    Each flow: same input (anchors [B,k,d], queries [B,n,d]) → output [B,n,d]
               plus a per-query confidence [B,n,1]
    Ensemble: confidence-weighted fusion with learned per-flow scaling
              ('weighted' by default; 'gated' and 'residual' also available)

Usage:
    from geolip.flows import FlowEnsemble, OrbitalFlow, AlignmentFlow

    ensemble = FlowEnsemble(
        flows=[OrbitalFlow(d_model=256, n_anchors=128),
               AlignmentFlow(d_model=256, n_anchors=128)],
        d_model=256,
    )
    output = ensemble(anchors, queries)  # [B, n, d]
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| from typing import List, Optional, Tuple |
|
|
| |
| |
try:
    import geolip_core.linalg as LA
except ImportError:
    import torch.linalg as _torch_linalg

    class _TorchLinalgShim:
        """Fallback stand-in for ``geolip_core.linalg`` built on ``torch.linalg``.

        The flows in this module call ``LA.eigh(..., method=...)`` and
        ``LA.svd(..., method=...)``. ``torch.linalg`` has no ``method``
        keyword, so aliasing it directly as ``LA`` made those calls raise
        ``TypeError``. This shim accepts ``method`` and ignores it; every
        other argument and every other attribute is forwarded to
        ``torch.linalg`` unchanged.
        """

        @staticmethod
        def eigh(A, *args, method=None, **kwargs):
            # ``method`` selects a geolip_core backend; torch has only one.
            return _torch_linalg.eigh(A, *args, **kwargs)

        @staticmethod
        def svd(A, *args, method=None, **kwargs):
            # ``method`` selects a geolip_core backend; torch has only one.
            return _torch_linalg.svd(A, *args, **kwargs)

        def __getattr__(self, name):
            # Anything else (norm, det, solve, ...) comes straight from torch.
            return getattr(_torch_linalg, name)

    LA = _TorchLinalgShim()
|
|
|
|
| |
| |
| |
|
|
class BaseFlow(nn.Module):
    """Common interface shared by every geometric flow in this module.

    A flow maps ``anchors [B, k, d]`` and ``queries [B, n, d]`` to a
    prediction ``[B, n, d]``. It also returns a per-query confidence
    ``[B, n, 1]`` in (0, 1), produced by a small MLP head that scores the
    prediction.

    Subclasses provide the actual math by overriding :meth:`_flow`.

    Args:
        d_model: embedding width ``d``.
        n_anchors: number of constellation anchors (kept for subclasses).
        name: short identifier for the flow.
    """

    def __init__(self, d_model: int, n_anchors: int, name: str = 'base'):
        super().__init__()
        self.d_model, self.n_anchors, self.name = d_model, n_anchors, name

        # Bottleneck MLP: d_model -> d_model // 4 -> 1 logit per query.
        hidden = d_model // 4
        self.confidence = nn.Sequential(
            nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, 1)
        )

    def forward(self, anchors: Tensor, queries: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the flow and score its output.

        Args:
            anchors: [B, k, d] constellation anchor points
            queries: [B, n, d] query embeddings

        Returns:
            prediction: [B, n, d] geometric prediction
            confidence: [B, n, 1] per-query confidence score in (0, 1)
        """
        prediction = self._flow(anchors, queries)
        logits = self.confidence(prediction)
        return prediction, logits.sigmoid()

    def _flow(self, anchors: Tensor, queries: Tensor) -> Tensor:
        """Compute the flow's prediction; subclasses must override this."""
        raise NotImplementedError
|
|
|
|
| |
| |
| |
|
|
class QuaternionFlow(BaseFlow):
    """Full multi-head attention with quaternion geometric rotation.

    Queries attend over the anchors with scaled dot-product multi-head
    attention. A unit quaternion is predicted from the attended context and
    used to rotate the first three channels of each query; the remaining
    channels pass through unchanged. The output is
    ``out_proj(context + rotated_queries)``.

    Heavyweight — the full-fidelity path.

    Args:
        d_model: embedding width ``d``. Must be divisible by ``n_heads`` and
            at least 3, because the rotation acts on the first three channels.
        n_anchors: number of constellation anchors.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``n_heads`` is not positive, ``d_model`` is not
            divisible by ``n_heads``, or ``d_model < 3``.
    """

    def __init__(self, d_model: int, n_anchors: int, n_heads: int = 4):
        # Validate up front: otherwise the mismatch only surfaces later as an
        # opaque ``.view()`` / ``torch.cross`` shape error inside forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        if d_model < 3:
            raise ValueError(f"d_model ({d_model}) must be >= 3 for quaternion rotation")
        super().__init__(d_model, n_anchors, name='quaternion')
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.quat_proj = nn.Linear(d_model, 4)

    def _flow(self, anchors: Tensor, queries: Tensor) -> Tensor:
        B, n, d = queries.shape
        k = anchors.shape[1]
        h, hd = self.n_heads, self.head_dim

        # Split into heads: [B, len, d] -> [B, h, len, hd].
        Q = self.q_proj(queries).view(B, n, h, hd).transpose(1, 2)
        K = self.k_proj(anchors).view(B, k, h, hd).transpose(1, 2)
        V = self.v_proj(anchors).view(B, k, h, hd).transpose(1, 2)

        attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(hd)
        attn = F.softmax(attn, dim=-1)
        # Merge heads back: [B, h, n, hd] -> [B, n, d].
        ctx = torch.matmul(attn, V).transpose(1, 2).reshape(B, n, d)

        # Unit quaternion (w, x, y, z) per query, derived from the context.
        q = F.normalize(self.quat_proj(ctx), dim=-1)
        rotated = self._quat_rotate(queries, q)
        return self.out_proj(ctx + rotated)

    def _quat_rotate(self, v: Tensor, q: Tensor) -> Tensor:
        """Rotate the first 3 channels of ``v`` by unit quaternion ``q``.

        Uses v' = v + w·t + u × t with t = 2 (u × v), where q = (w, u). This
        is the expansion of q v q* for a unit quaternion.

        Args:
            v: [B, n, d] vectors, d >= 3; channels 3.. are passed through.
            q: [B, n, 4] unit quaternions ordered (w, x, y, z).
        """
        w, xyz = q[..., 0:1], q[..., 1:4]
        v3 = v[..., :3]

        t = 2.0 * torch.cross(xyz, v3, dim=-1)
        v3_rot = v3 + w * t + torch.cross(xyz, t, dim=-1)
        if v.shape[-1] > 3:
            return torch.cat([v3_rot, v[..., 3:]], dim=-1)
        return v3_rot
|
|
|
|
| |
| |
| |
|
|
class QuaternionLiteFlow(BaseFlow):
    """Lightweight quaternion prediction without full multi-head attention.

    Compresses the anchors into a single centroid context and concatenates
    it with a projection of each query. From that concatenation it predicts
    a unit quaternion, which rotates the first three channels of the query.
    Much lighter than the full QuaternionFlow — trades attention resolution
    for speed.
    """

    def __init__(self, d_model: int, n_anchors: int):
        super().__init__(d_model, n_anchors, name='quat_lite')
        self.anchor_compress = nn.Linear(d_model, d_model)
        self.query_proj = nn.Linear(d_model, d_model)
        self.quat_head = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, 4),
        )
        self.out_proj = nn.Linear(d_model, d_model)

    def _flow(self, anchors, queries):
        # One shared context vector per batch element, broadcast to every query.
        centroid = anchors.mean(dim=1, keepdim=True)
        context = self.anchor_compress(centroid).expand_as(queries)

        features = torch.cat([self.query_proj(queries), context], dim=-1)
        unit_quat = F.normalize(self.quat_head(features), dim=-1)
        return self.out_proj(self._quat_rotate_simple(queries, unit_quat))

    def _quat_rotate_simple(self, v, q):
        """Rotate the leading 3 channels of ``v`` by unit quaternion ``q`` (w, x, y, z)."""
        scalar, axis = q[..., 0:1], q[..., 1:4]
        head, tail = v[..., :3], v[..., 3:]
        twice_cross = 2.0 * torch.cross(axis, head, dim=-1)
        head_rot = head + scalar * twice_cross + torch.cross(axis, twice_cross, dim=-1)
        return head_rot if tail.shape[-1] == 0 else torch.cat([head_rot, tail], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class VelocityFlow(BaseFlow):
    """Angular-velocity flow on the tangent space of the constellation.

    Models dq/dt, the rate of change of each query embedding induced by the
    anchor geometry. It integrates this with one explicit Euler step whose
    size ``dt`` is learnable.

    The raw velocity has its component along each query's direction removed,
    so the step is tangent to the hypersphere at that query. The Euler step
    itself is not renormalised.
    """

    def __init__(self, d_model: int, n_anchors: int):
        super().__init__(d_model, n_anchors, name='velocity')
        self.anchor_proj = nn.Linear(d_model, d_model)
        self.query_proj = nn.Linear(d_model, d_model)
        self.vel_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        # Euler step size, learned jointly with the rest of the flow.
        self.dt = nn.Parameter(torch.tensor(0.1))

    def _flow(self, anchors, queries):
        keys = self.anchor_proj(anchors)
        probes = self.query_proj(queries)

        # Scaled dot-product attention over the anchors gives each query an
        # anchor-driven "destination"; the velocity is predicted from the gap.
        logits = torch.bmm(probes, keys.transpose(-2, -1)) / math.sqrt(queries.shape[-1])
        destination = torch.bmm(F.softmax(logits, dim=-1), keys)
        velocity = self.vel_head(destination - probes)

        # Strip the radial component so the update is tangent to the sphere.
        unit = F.normalize(queries, dim=-1)
        tangent = velocity - unit * (velocity * unit).sum(dim=-1, keepdim=True)

        return queries + self.dt * tangent
|
|
|
|
| |
| |
| |
|
|
class MagnitudeFlow(BaseFlow):
    """Flow based on the Gram-matrix eigenvalue magnitude spectrum.

    The anchors are projected to ``geom_dim = min(n_anchors, 12)`` dimensions,
    and their Gram matrix ``G = Aᵀ A`` ([B, g, g]) is eigendecomposed via FL
    eigh. The magnitudes ``sqrt(|λ|)`` form a per-batch spectral profile.
    That profile is embedded to ``d_model`` and added to every query through
    a query-dependent sigmoid gate.

    The eigenvalue magnitudes encode the constellation's energy distribution
    across geometric modes.

    Args:
        d_model: embedding width ``d``.
        n_anchors: number of anchors; caps ``geom_dim`` at 12.
        eps: floor applied to ``|λ|`` before the square root. ``sqrt`` has an
            infinite derivative at 0. A zero eigenvalue (e.g. a rank-deficient
            Gram matrix) would otherwise give inf/NaN gradients. Only values
            with ``|λ| < eps`` are affected.
    """

    def __init__(self, d_model: int, n_anchors: int, eps: float = 1e-8):
        super().__init__(d_model, n_anchors, name='magnitude')
        self.geom_dim = min(n_anchors, 12)
        self.eps = eps
        self.anchor_proj = nn.Linear(d_model, self.geom_dim)

        self.spec_proj = nn.Sequential(
            nn.Linear(self.geom_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.query_proj = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(d_model * 2, d_model)

    def _flow(self, anchors, queries):
        B, n, d = queries.shape

        a_geom = self.anchor_proj(anchors)                 # [B, k, g]
        G = torch.bmm(a_geom.transpose(-2, -1), a_geom)    # [B, g, g], symmetric PSD

        # NOTE: requires ``LA.eigh`` to accept ``method=`` (presumably the
        # geolip_core API; plain torch.linalg.eigh does not).
        eigenvalues, _ = LA.eigh(G, method='torch')        # [B, g]

        # abs() guards tiny negative round-off; clamp keeps the sqrt gradient finite.
        magnitudes = eigenvalues.abs().clamp_min(self.eps).sqrt()
        spec_embed = self.spec_proj(magnitudes).unsqueeze(1).expand(B, n, d)

        # Per-query gate decides how much of the shared spectral embedding to add.
        q_proj = self.query_proj(queries)
        g = torch.sigmoid(self.gate(torch.cat([q_proj, spec_embed], dim=-1)))
        return queries + g * spec_embed
|
|
|
|
| |
| |
| |
|
|
class OrbitalFlow(BaseFlow):
    """Omega-based orbital resonance flow.

    The anchors are projected to ``geom_dim = min(n_anchors, 12)``
    dimensions, and their Gram matrix ``G = Aᵀ A`` ([B, g, g]) is
    eigendecomposed. Its eigenvalues λᵢ correspond to resonance frequencies
    ωᵢ = √λᵢ. The flow then:

    1. projects the queries into the same space;
    2. expresses them in the eigenbasis;
    3. rescales each mode by a frequency-dependent weight;
    4. maps the result back to ``d_model`` and adds it to the queries.

    Per batch element and mode index i (eigenvalues in ``LA.eigh`` order):

        w_i = mode_response_i * (1 + in_band_i + near_binding_i)

    where
        in_band_i      = 1 if cv_lo <= λᵢ <= cv_hi else 0
                         (default λ ∈ [0.20, 0.23], i.e. ω ∈ [0.447, 0.480])
        near_binding_i = exp(-binding_sharpness * (λᵢ - binding_center)²)

    The multiplier on ``mode_response`` is always >= 1. In-band and
    near-binding modes are therefore boosted, and the remaining modes are
    weaker only relative to them (they are not attenuated in absolute terms).

    Args:
        d_model: embedding width ``d``.
        n_anchors: number of anchors; caps ``geom_dim`` at 12.
        cv_lo, cv_hi: inclusive eigenvalue band that receives the +1 boost.
        binding_center: centre (in eigenvalue units) of the Gaussian boost.
        binding_sharpness: coefficient of the Gaussian's exponent.

    Raises:
        ValueError: if ``cv_lo > cv_hi`` (the band would select nothing).
    """

    def __init__(
        self,
        d_model: int,
        n_anchors: int,
        cv_lo: float = 0.20,
        cv_hi: float = 0.23,
        binding_center: float = 0.29154,
        binding_sharpness: float = 10.0,
    ):
        if cv_lo > cv_hi:
            raise ValueError(f"cv_lo ({cv_lo}) must be <= cv_hi ({cv_hi})")
        super().__init__(d_model, n_anchors, name='orbital')
        self.geom_dim = min(n_anchors, 12)
        self.anchor_proj = nn.Linear(d_model, self.geom_dim)
        self.cv_lo = cv_lo
        self.cv_hi = cv_hi
        self.binding_center = binding_center
        self.binding_sharpness = binding_sharpness

        # Learnable per-mode gain (indexed by eigenvalue position).
        self.mode_response = nn.Parameter(torch.ones(self.geom_dim))

        self.query_to_geom = nn.Linear(d_model, self.geom_dim)
        self.geom_to_query = nn.Linear(self.geom_dim, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def _flow(self, anchors, queries):
        a_geom = self.anchor_proj(anchors)                    # [B, k, g]
        G = torch.bmm(a_geom.transpose(-2, -1), a_geom)       # [B, g, g]

        # NOTE: requires ``LA.eigh`` to accept ``method=`` (presumably the
        # geolip_core API; plain torch.linalg.eigh does not). Eigenvector
        # gradients are ill-conditioned when eigenvalues are (near-)equal.
        eigenvalues, eigenvectors = LA.eigh(G, method='torch')  # [B, g], [B, g, g]

        in_band = ((eigenvalues >= self.cv_lo) & (eigenvalues <= self.cv_hi)).to(eigenvalues.dtype)
        near_binding = torch.exp(
            -self.binding_sharpness * (eigenvalues - self.binding_center).pow(2)
        )
        mode_weight = self.mode_response.unsqueeze(0) * (1.0 + in_band + near_binding)  # [B, g]

        q_geom = self.query_to_geom(queries)                  # [B, n, g]
        q_eigen = torch.bmm(q_geom, eigenvectors)             # coordinates in the eigenbasis
        q_modulated = q_eigen * mode_weight.unsqueeze(1)      # per-mode rescale
        q_out = torch.bmm(q_modulated, eigenvectors.transpose(-2, -1))  # back to geom coords

        return self.out_proj(self.geom_to_query(q_out) + queries)
|
|
|
|
| |
| |
| |
|
|
class AlignmentFlow(BaseFlow):
    """SVD alignment flow via soft Procrustes rotation in projected space.

    Anchors and queries are projected to ``geom_dim = min(n_anchors, 12)``
    dimensions. Each projected query ``Q`` gets a soft target ``T``: the
    attention-weighted mean of the projected anchors.

    From the SVD ``Qᵀ T = U S Vh``, the matrix ``R = U Vh`` is the orthogonal
    Procrustes solution minimising ``||Q R - T||_F``. The displacement
    ``Q R - Q`` is mapped back to ``d_model`` and added to the queries,
    scaled by a learnable ``strength``.

    ``R`` is orthogonal but may be a reflection (det = -1). With
    ``proper_rotation=True`` the Kabsch sign correction is applied so that
    ``det(R) = +1``.

    Args:
        d_model: embedding width ``d``.
        n_anchors: number of anchors; caps ``geom_dim`` at 12.
        proper_rotation: enforce a proper rotation. The default False keeps
            the unconstrained orthogonal solution (original behaviour).
    """

    def __init__(self, d_model: int, n_anchors: int, proper_rotation: bool = False):
        super().__init__(d_model, n_anchors, name='alignment')
        self.geom_dim = min(n_anchors, 12)
        self.anchor_proj = nn.Linear(d_model, self.geom_dim)
        self.query_proj = nn.Linear(d_model, self.geom_dim)
        self.geom_to_query = nn.Linear(self.geom_dim, d_model)
        self.strength = nn.Parameter(torch.tensor(0.1))
        self.proper_rotation = proper_rotation

    def _flow(self, anchors, queries):
        a_proj = self.anchor_proj(anchors)   # [B, k, g]
        q_proj = self.query_proj(queries)    # [B, n, g]

        # Soft correspondence: each query's target is an attention-weighted anchor mix.
        sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(self.geom_dim)
        weights = F.softmax(sim, dim=-1)
        targets = torch.bmm(weights, a_proj)                 # [B, n, g]

        C = torch.bmm(q_proj.transpose(-2, -1), targets)     # [B, g, g] cross-covariance

        # NOTE: requires ``LA.svd`` to accept ``method=`` (presumably the
        # geolip_core API; plain torch.linalg.svd does not).
        U, _, Vh = LA.svd(C, method='gram_eigh')
        if self.proper_rotation:
            # Kabsch: flip the last singular direction when U @ Vh is a
            # reflection. Assumes singular values come in descending order,
            # as torch.linalg.svd returns them. The sign is piecewise constant,
            # so it is detached.
            det = torch.linalg.det(torch.bmm(U, Vh)).detach()
            sign = 1.0 - 2.0 * (det < 0).to(U.dtype)         # +1 or -1 per batch
            U = torch.cat([U[..., :-1], U[..., -1:] * sign[:, None, None]], dim=-1)
        R = torch.bmm(U, Vh)

        q_rotated = torch.bmm(q_proj, R)
        delta = self.geom_to_query(q_rotated - q_proj)
        return queries + self.strength * delta
|
|
|
|
| |
| |
| |
|
|
class FlowEnsemble(nn.Module):
    """Ensemble fusion of multiple geometric flows.

    Every flow is run on the same ``(anchors, queries)`` and returns a
    prediction ``[B, n, d]`` and a confidence ``[B, n, 1]``. Each flow's
    confidence is multiplied by a learnable per-flow factor
    ``self.temperature[i]`` before fusion. The factor multiplies the
    confidence, so it acts as an inverse temperature for the softmax.

    Fusion modes:
        'weighted': a softmax over the scaled confidences (across flows) gives
            per-query weights, which average the predictions.
        'gated':    an MLP over the concatenated predictions produces the
            output. Confidences are computed but not used in this mode.
        'residual': the same weights applied to the residuals
            ``pred - queries``, which are then added back to the queries.
            Because the weights sum to one, this equals 'weighted' up to
            floating-point rounding.

    Args:
        flows: flow modules. Each must return ``(pred, conf)`` and expose a
            ``name`` attribute (used by :meth:`flow_diagnostics`).
        d_model: embedding width ``d``.
        fusion: one of ``'weighted'``, ``'gated'``, ``'residual'``.

    Raises:
        ValueError: if ``flows`` is empty or ``fusion`` is unknown.
    """

    _FUSIONS = ('weighted', 'gated', 'residual')

    # Annotation is quoted so this class does not need BaseFlow at definition time.
    def __init__(self, flows: "List[BaseFlow]", d_model: int, fusion: str = 'weighted'):
        super().__init__()
        if fusion not in self._FUSIONS:
            raise ValueError(f"Unknown fusion: {fusion}")
        flows = list(flows)
        if not flows:
            raise ValueError("FlowEnsemble requires at least one flow")

        self.flows = nn.ModuleList(flows)
        self.d_model = d_model
        self.fusion = fusion
        self.n_flows = len(flows)

        if fusion == 'gated':
            self.gate = nn.Sequential(
                nn.Linear(d_model * self.n_flows, d_model),
                nn.GELU(),
                nn.Linear(d_model, d_model),
            )

        # Per-flow multiplicative scale on confidence (see class docstring).
        self.temperature = nn.Parameter(torch.ones(self.n_flows))

    def forward(self, anchors: Tensor, queries: Tensor) -> Tensor:
        """
        Args:
            anchors: [B, k, d] constellation anchors
            queries: [B, n, d] query embeddings

        Returns:
            fused: [B, n, d] ensemble prediction
        """
        predictions = []
        confidences = []
        for i, flow in enumerate(self.flows):
            pred, conf = flow(anchors, queries)
            predictions.append(pred)
            confidences.append(conf * self.temperature[i])

        if self.fusion == 'weighted':
            return self._weighted_fusion(predictions, confidences)
        elif self.fusion == 'gated':
            return self._gated_fusion(predictions, confidences)
        elif self.fusion == 'residual':
            return self._residual_fusion(predictions, confidences, queries)
        else:
            # Reachable only if ``self.fusion`` was changed after construction.
            raise ValueError(f"Unknown fusion: {self.fusion}")

    @staticmethod
    def _fusion_weights(confs):
        """Softmax across flows of the [B, n, 1] confidences -> [B, n, n_flows]."""
        return F.softmax(torch.cat(confs, dim=-1), dim=-1)

    def _weighted_fusion(self, preds, confs):
        weights = self._fusion_weights(confs)
        pred_stack = torch.stack(preds, dim=-1)              # [B, n, d, n_flows]
        return (pred_stack * weights.unsqueeze(-2)).sum(dim=-1)

    def _gated_fusion(self, preds, confs):
        # Confidences are intentionally unused here; the MLP sees raw predictions.
        return self.gate(torch.cat(preds, dim=-1))

    def _residual_fusion(self, preds, confs, queries):
        weights = self._fusion_weights(confs)
        residuals = torch.stack([p - queries for p in preds], dim=-1)
        fused_residual = (residuals * weights.unsqueeze(-2)).sum(dim=-1)
        return queries + fused_residual

    @torch.no_grad()
    def flow_diagnostics(self, anchors: Tensor, queries: Tensor) -> dict:
        """Run every flow (without autograd) and return per-flow scalar stats.

        Keys are ``flow.name``. A name that repeats an earlier flow's name
        gets the flow's index appended (``"<name>_<i>"``), so no flow's stats
        are silently overwritten.
        """
        diag = {}
        for i, flow in enumerate(self.flows):
            pred, conf = flow(anchors, queries)
            key = flow.name if flow.name not in diag else f"{flow.name}_{i}"
            diag[key] = {
                'pred_norm': pred.norm(dim=-1).mean().item(),
                'confidence_mean': conf.mean().item(),
                # std of a single element is undefined (NaN); report 0.0 instead.
                'confidence_std': conf.std().item() if conf.numel() > 1 else 0.0,
                'residual_norm': (pred - queries).norm(dim=-1).mean().item(),
                'temperature': self.temperature[i].item(),
            }
        return diag