"""
Mathematical Foundation & Conceptual Documentation
--------------------------------------------------

CORE PRINCIPLE:
Combines graph neural networks with Hopfield associative memories and
decision-tree-like edge branching to create networks in which both nodes and
edges have memory capabilities. Each edge can dynamically route through
multiple relation hypotheses, enabling adaptive graph reasoning with
memory-augmented message passing.

MATHEMATICAL FOUNDATION:
========================

1. HOPFIELD MEMORY MECHANICS:
   Energy function: E = -½ ∑ᵢⱼ wᵢⱼ sᵢ sⱼ + ∑ᵢ θᵢ sᵢ

   Where:
   - wᵢⱼ: connection weights between memory units
   - sᵢ, sⱼ: activation states of units i, j
   - θᵢ: threshold for unit i
   - E: system energy (minimized during retrieval)

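As a toy illustration (not part of the module API; the weights, thresholds,
and states below are made up), the energy can be evaluated directly for a
small binary network:

```python
# Toy evaluation of the Hopfield energy E = -1/2 * s^T W s + theta . s.
import numpy as np

rng = np.random.default_rng(0)
n = 4
W = rng.normal(size=(n, n))
W = (W + W.T) / 2.0          # Hopfield weights are symmetric
np.fill_diagonal(W, 0.0)     # no self-connections
theta = np.zeros(n)          # zero thresholds for simplicity
s = np.array([1.0, -1.0, 1.0, -1.0])

E = -0.5 * s @ W @ s + theta @ s
print(float(E))
```

Flipping a single unit and re-evaluating E shows how retrieval dynamics
move downhill on this landscape.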
2. ASSOCIATIVE MEMORY RETRIEVAL:
   Content addressing: aₜ = softmax(βₜ · K(q, M))

   Where:
   - q: query vector
   - M: memory matrix [memory_slots, memory_dim]
   - K(q, M): similarity function (typically cosine similarity)
   - βₜ: temperature parameter (attention sharpness)
   - aₜ: attention weights over memory slots

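A minimal sketch of content addressing with cosine similarity (the names q,
M, beta follow the definitions above; the data is random and illustrative):

```python
# Content addressing: a = softmax(beta * cos_sim(q, M)) over memory slots.
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
M = rng.normal(size=(8, 16))             # [memory_slots, memory_dim]
q = M[3] + 0.1 * rng.normal(size=16)     # noisy cue near slot 3
beta = 8.0                               # attention sharpness

cos = (M @ q) / (np.linalg.norm(M, axis=1) * np.linalg.norm(q) + 1e-9)
a = softmax(beta * cos)
print(a.argmax())                        # the cue's own slot dominates
```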
3. DECISION GATE BRANCHING:
   Branch weights: w = softmax(EdgeScorer(concat(xᵢ, xⱼ)) / τ)

   Where:
   - xᵢ, xⱼ: node features for edge (i, j)
   - EdgeScorer: neural network mapping edge features to K branch logits
   - τ: temperature parameter
   - w: simplex over K relation hypotheses

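The branch simplex can be sketched with a stand-in scorer (the random weight
matrix below is a hypothetical placeholder for EdgeScorer, which in the
module is a small MLP):

```python
# Branch weights: w = softmax(EdgeScorer(concat(x_i, x_j)) / tau).
import numpy as np

rng = np.random.default_rng(0)
D, K, tau = 8, 4, 0.7
x_i, x_j = rng.normal(size=D), rng.normal(size=D)
W_score = rng.normal(size=(K, 2 * D)) / np.sqrt(2 * D)  # stand-in scorer

logits = W_score @ np.concatenate([x_i, x_j])
z = (logits / tau) - (logits / tau).max()   # stable softmax
w = np.exp(z)
w /= w.sum()
print(w.round(3))                           # a point on the (K-1)-simplex
```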
4. MESSAGE PASSING WITH MEMORY:
   Node update: hᵢ⁽ˡ⁺¹⁾ = NodeMemory(hᵢ⁽ˡ⁾ + ∑ⱼ∈N(i) Aᵢⱼ · EdgeMemory(hᵢ⁽ˡ⁾, hⱼ⁽ˡ⁾))

   Where:
   - hᵢ⁽ˡ⁾: node representation at layer l
   - N(i): neighbors of node i
   - Aᵢⱼ: attention-weighted adjacency (after decision branching)
   - NodeMemory, EdgeMemory: Hopfield memory modules

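The aggregation step can be sketched with the memory modules stubbed out
(EdgeMemory replaced by the pairwise difference hᵢ - hⱼ that the layer below
feeds it, NodeMemory by the identity; purely illustrative):

```python
# One message-passing step: h_i <- h_i + sum_j A_ij * m_ij,
# with m_ij = h_i - h_j standing in for EdgeMemory.
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 3
h = rng.normal(size=(N, D))
A = (rng.random((N, N)) < 0.4).astype(float)
np.fill_diagonal(A, 0.0)                 # no self-loops

m = h[:, None, :] - h[None, :, :]        # m[i, j] = h_i - h_j
agg = np.einsum("ij,ijd->id", A, m)      # adjacency-weighted message sum
h_next = h + agg
print(h_next.shape)
```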
5. BARYCENTRIC EDGE MERGING:
   A'ᵢⱼ = (Aᵢⱼ > 0) · meanₖ(wᵢⱼₖ)

   Where:
   - A'ᵢⱼ: merged edge weight
   - wᵢⱼₖ: weight for branch k on edge (i, j)
   - Keeps edge weights in the convex hull of the branch hypotheses

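The merge formula translates directly into a few array operations (toy
tensors; the module's `barycentric_merge` below does the batched torch
equivalent):

```python
# Barycentric merge: A'_ij = (A_ij > 0) * mean_k(w_ijk).
import numpy as np

rng = np.random.default_rng(0)
N, K = 4, 3
A = (rng.random((N, N)) < 0.5).astype(float)
w = rng.random((N, N, K))
w /= w.sum(axis=-1, keepdims=True)       # simplex per edge

A_merged = (A > 0) * w.mean(axis=-1)     # merged weights stay in [0, 1]
print(A_merged.shape)
```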
6. ENERGY MINIMIZATION:
   dh/dt = -∂E/∂h, where E is the Hopfield energy

   Memory retrieval follows gradient descent on the energy landscape.

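For modern Hopfield networks, one descent step coincides with a softmax
attention update. A toy fixed-point iteration (orthogonal stored patterns
and a made-up noisy cue) illustrates the resulting pattern completion:

```python
# Retrieval as iterated xi <- M^T softmax(beta * M xi): descent-like
# dynamics on the modern Hopfield energy, settling on a stored pattern.
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

M = np.eye(4)                            # four orthogonal stored patterns
xi = np.array([0.9, 0.2, 0.1, 0.0])      # noisy cue near pattern 0
beta = 4.0
for _ in range(5):
    xi = M.T @ softmax(beta * (M @ xi))
print(xi.argmax())                       # the cue settles near pattern 0
```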
CONCEPTUAL REASONING:
====================

WHY HOPFIELD + GRAPHS + DECISION BRANCHING?
- Standard GNNs assume fixed edge semantics
- Real-world graphs have ambiguous, multi-faceted relationships
- Hopfield memories provide content-addressable associative recall
- Decision branching enables soft routing through relation types
- Memory-augmented edges learn context-dependent message passing

KEY INNOVATIONS:
1. **Dual Memory Architecture**: Both nodes and edges have associative memories
2. **Decision-Tree Edge Routing**: Soft branching through K relation hypotheses
3. **Hard/Soft Routing Modes**: Soft routing during training, deterministic
   (arg-max) routing during evaluation
4. **Energy-Based Retrieval**: Hopfield dynamics for memory access
5. **Hierarchical Message Passing**: Memory → branching → aggregation

APPLICATIONS:
- Knowledge graph reasoning with uncertain relations
- Social network analysis with multi-type interactions
- Molecular property prediction with bond ambiguity
- Recommendation systems with multi-faceted user-item relations
- Program analysis with context-dependent variable relationships

COMPLEXITY ANALYSIS:
- Node memory: O(N · D · S) where N = nodes, D = features, S = memory_slots
- Edge memory: O(E · D · S) where E = edges
- Decision branching: O(E · K) where K = branch_count
- Message passing: O(E · D + N · D²) per layer
- Memory footprint: O((N + E) · S · D) for stored patterns

BIOLOGICAL INSPIRATION:
- Hippocampal pattern completion and separation
- Cortical associative memory networks
- Synaptic plasticity and connection-strength modulation
- Neural circuit motifs with context-dependent routing
- Memory consolidation through repeated activation patterns
"""
|
|
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Dict, Optional, Protocol, Tuple, TypedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)
if not logger.handlers:
    _h = logging.StreamHandler()
    _f = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s")
    _h.setFormatter(_f)
    logger.addHandler(_h)
    logger.setLevel(logging.INFO)
|
|
class GraphShapeError(ValueError):
    """Raised when provided tensors do not match expected graph shapes."""


class RoutingError(RuntimeError):
    """Raised when routing/branching fails due to numerical or config issues."""


class AuxOut(TypedDict):
    """Auxiliary outputs returned by the model."""

    branch_weights: torch.Tensor
    hopfield_node_energy: torch.Tensor
    hopfield_edge_energy: torch.Tensor
|
|
class MergeStrategy(Protocol):
    """Protocol for merging K branch-specific edge weights into an adjacency."""

    def __call__(self, base_adj: torch.Tensor, branch_weights: torch.Tensor) -> torch.Tensor:
        """Merge branch weights into an augmented adjacency.

        Args:
            base_adj: (B, N, N) original adjacency (0/1 or weighted).
            branch_weights: (B, N, N, K) simplex per edge over K branches.

        Returns:
            (B, N, N) merged adjacency.
        """
        ...


def barycentric_merge(base_adj: torch.Tensor, branch_weights: torch.Tensor) -> torch.Tensor:
    """Barycentric merge of branches into a single weighted adjacency.

    This keeps edge weights in the convex hull of the branch hypotheses. It is
    simple, stable, and differentiable.

    Mathematical details:
    - Take the mean of the branch weights: w̄ = (1/K) Σₖ wₖ
    - Apply it to existing edges: A'ᵢⱼ = (Aᵢⱼ > 0) · w̄ᵢⱼ
    - Preserves graph structure while allowing soft edge weights

    Args:
        base_adj: (B, N, N) - original adjacency matrix.
        branch_weights: (B, N, N, K) - branch weight simplex.

    Returns:
        (B, N, N) - merged adjacency with barycentric edge weights.
    """
    bw = branch_weights.mean(dim=-1)
    return (base_adj > 0).to(base_adj.dtype) * bw
|
|
class HopfieldMemory(nn.Module):
    """Hopfield associative memory with content-based retrieval.

    Implements a modern Hopfield network that stores patterns in key-value
    memory slots and retrieves them via content-based attention. Uses a
    temperature-controlled softmax to set retrieval sharpness.

    Mathematical framework:
    - Keys: K ∈ ℝˢˣᴰ (memory slots × feature dimension)
    - Values: V ∈ ℝˢˣᴰ (stored pattern values)
    - Query: q ∈ ℝᴰ (input pattern for retrieval)
    - Attention: α = softmax(q Kᵀ / (√D · scale))
    - Output: o = αᵀ V (weighted combination of stored patterns)

    The "energy" proxy measures retrieval sharpness (low attention entropy =
    high energy).
    """

    def __init__(self, dim: int, mem_slots: int = 64, key_scale: float = 1.0) -> None:
        super().__init__()
        self.dim = dim
        self.mem_slots = mem_slots
        self.key_scale = float(key_scale)

        # Learnable memory slots, initialized at 1/sqrt(D) scale.
        self.keys = nn.Parameter(torch.randn(mem_slots, dim) * (1.0 / np.sqrt(dim)))
        self.vals = nn.Parameter(torch.randn(mem_slots, dim) * (1.0 / np.sqrt(dim)))

        self.proj_q = nn.Linear(dim, dim, bias=False)
        self.proj_o = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Retrieve patterns from associative memory via content addressing.

        Implements the core Hopfield retrieval mechanism:
        1. Project the input to query space
        2. Compute content-based attention over memory keys
        3. Retrieve a weighted combination of stored values
        4. Project the output and compute the energy proxy

        Mathematical details:
        - Query projection: q = W_q · x
        - Similarity: simᵢ = qᵀ Kᵢ for each memory slot i
        - Attention: αᵢ = exp(simᵢ / scale) / Σⱼ exp(simⱼ / scale)
        - Retrieval: r = Σᵢ αᵢ · Vᵢ
        - Output: o = W_o · r

        Args:
            x: Input tensor (..., D) - query patterns.

        Returns:
            Tuple of:
            - out: Retrieved patterns (..., D)
            - energy: Scalar energy proxy (higher = more focused retrieval)
        """
        q = self.proj_q(x)

        # Scaled dot-product attention over memory keys.
        scale = np.sqrt(self.dim) * max(self.key_scale, 1e-6)
        attn = F.softmax((q @ self.keys.T) / scale, dim=-1)

        retrieved = attn @ self.vals
        out = self.proj_o(retrieved)

        # Energy proxy: negative attention entropy (focused retrieval = high energy).
        p = attn.clamp_min(1e-9)
        entropy = -(p * p.log()).sum(dim=-1).mean()
        energy = -entropy

        return out, energy
|
|
class DecisionGate(nn.Module):
    """Learnable branching gate for multi-hypothesis edge routing.

    Implements soft decision-tree routing where each edge can branch through
    K different relation hypotheses. Uses pairwise node features to predict
    branch probabilities, enabling context-dependent edge semantics.

    Mathematical framework:
    - Edge features: eᵢⱼ = concat(xᵢ, xⱼ) ∈ ℝ²ᴰ
    - Branch logits: lᵢⱼ = MLP(eᵢⱼ) ∈ ℝᴷ
    - Branch weights: wᵢⱼ = softmax(lᵢⱼ / τ) ∈ Δᴷ⁻¹

    Supports both soft routing (training) and hard routing (evaluation)
    for computational efficiency and interpretability.
    """

    def __init__(self, dim: int, branches: int = 4, temperature: float = 0.7, hard_eval: bool = True) -> None:
        super().__init__()
        if branches < 1:
            raise ValueError("branches must be >= 1")
        self.dim = dim
        self.K = branches
        self.temperature = float(temperature)
        self.hard_eval = bool(hard_eval)

        self.edge_scorer = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.GELU(),
            nn.Linear(dim, self.K),
        )

    @staticmethod
    def _pairwise_concat(x: torch.Tensor) -> torch.Tensor:
        """Create pairwise concatenations for all edge combinations.

        Generates edge feature representations by concatenating all pairs
        of node features, producing a complete edge feature tensor.

        Mathematical details:
        - For nodes X = [x₁, x₂, ..., xₙ] ∈ ℝᴺˣᴰ
        - Create eᵢⱼ = [xᵢ; xⱼ] ∈ ℝ²ᴰ for all pairs (i, j)
        - Result: E ∈ ℝᴺˣᴺˣ²ᴰ

        Args:
            x: Node features (B, N, D).

        Returns:
            Edge features (B, N, N, 2D) - concatenated pairwise features.
        """
        B, N, D = x.shape
        xi = x.unsqueeze(2).expand(B, N, N, D)
        xj = x.unsqueeze(1).expand(B, N, N, D)
        return torch.cat([xi, xj], dim=-1)

    def forward(self, x: torch.Tensor, mask_adj: torch.Tensor) -> torch.Tensor:
        """Compute branch routing probabilities for each edge.

        Determines how each edge should route through K relation hypotheses
        based on the features of its endpoint nodes. Masked edges receive
        zero routing weights.

        Mathematical process:
        1. Create pairwise edge features eᵢⱼ = [xᵢ; xⱼ]
        2. Compute branch logits lᵢⱼ = EdgeScorer(eᵢⱼ)
        3. Apply temperature and softmax: wᵢⱼ = softmax(lᵢⱼ / τ)
        4. Mask non-edges and renormalize

        Args:
            x: Node features (B, N, D).
            mask_adj: Edge mask (B, N, N) - 1 for valid edges, 0 for non-edges.

        Returns:
            Branch weights (B, N, N, K) - simplex per edge over K branches.
        """
        if x.dim() != 3:
            raise GraphShapeError("x must be (B,N,D)")
        if mask_adj.dim() != 3:
            raise GraphShapeError("mask_adj must be (B,N,N)")
        B, N, D = x.shape
        if mask_adj.shape != (B, N, N):
            raise GraphShapeError("mask_adj shape mismatch")

        edge_feats = self._pairwise_concat(x)
        logits = self.edge_scorer(edge_feats)

        temp = max(self.temperature, 1e-5)
        if self.training:
            # Soft routing: differentiable simplex over branches.
            weights = F.softmax(logits / temp, dim=-1)
        else:
            w = F.softmax(logits / temp, dim=-1)
            if self.hard_eval:
                # Hard routing: one-hot arg-max per edge for interpretability.
                idx = w.argmax(dim=-1, keepdim=True)
                weights = torch.zeros_like(w).scatter_(-1, idx, 1.0)
            else:
                weights = w

        # Zero out non-edges and renormalize the remaining mass.
        weights = weights * mask_adj.unsqueeze(-1)
        sums = weights.sum(dim=-1, keepdim=True)
        weights = torch.where(sums > 0, weights / sums.clamp_min(1e-9), weights)

        return weights
|
|
class HopfieldDecisionLayer(nn.Module):
    """Graph neural network layer with Hopfield memories and decision branching.

    Integrates three key components:
    1. Node-level Hopfield memory for pattern completion
    2. Edge-level Hopfield memory for relation-specific message encoding
    3. Decision gate for soft routing through relation hypotheses

    This creates a message-passing framework in which both nodes and edges
    have associative memory capabilities, and edge semantics can dynamically
    adapt based on context.

    Mathematical flow:
    1. Node memory retrieval: h'ᵢ = HopfieldNode(hᵢ)
    2. Edge message encoding: mᵢⱼ = HopfieldEdge(hᵢ - hⱼ)
    3. Decision branching: wᵢⱼ = DecisionGate(hᵢ, hⱼ)
    4. Message aggregation: m̄ᵢ = Σⱼ A'ᵢⱼ · mᵢⱼ
    5. Node update: hᵢ⁽ˡ⁺¹⁾ = LayerNorm(hᵢ + MLP([hᵢ + h'ᵢ; m̄ᵢ]))
    """

    def __init__(
        self,
        dim: int,
        mem_slots_nodes: int = 64,
        mem_slots_edges: int = 32,
        branches: int = 4,
        temperature: float = 0.7,
        hard_eval: bool = True,
        merge: Optional[MergeStrategy] = None,
    ) -> None:
        super().__init__()

        self.node_mem = HopfieldMemory(dim, mem_slots_nodes)
        self.edge_mem = HopfieldMemory(dim, mem_slots_edges)
        self.gate = DecisionGate(dim, branches, temperature, hard_eval)
        self.merge = merge or barycentric_merge

        self.msg_mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, AuxOut]:
        """Execute one layer of memory-augmented graph message passing.

        Implements the complete forward pass combining Hopfield memory
        retrieval, decision branching, and message aggregation.

        Mathematical algorithm:
        1. Retrieve node patterns: h'ᵢ = NodeMemory(hᵢ)
        2. Encode edge messages: mᵢⱼ = EdgeMemory(hᵢ - hⱼ)
        3. Compute branch routing: wᵢⱼ = DecisionGate(hᵢ, hⱼ)
        4. Merge adjacency: A' = MergeStrategy(A, w)
        5. Aggregate messages: m̄ᵢ = Σⱼ A'ᵢⱼ · MLP(mᵢⱼ) / deg(i)
        6. Update nodes: hᵢ⁽ˡ⁺¹⁾ = LayerNorm(hᵢ + MLP([hᵢ + h'ᵢ; m̄ᵢ]))

        Args:
            x: Node features (B, N, D).
            A: Adjacency matrix (B, N, N) - 0/1 or weighted.

        Returns:
            Tuple of:
            - x_next: Updated node features (B, N, D)
            - aux: Auxiliary outputs (branch weights, energy values)
        """
        if x.dim() != 3:
            raise GraphShapeError("x must be (B,N,D)")
        if A.dim() != 3:
            raise GraphShapeError("A must be (B,N,N)")
        B, N, D = x.shape
        if A.shape != (B, N, N):
            raise GraphShapeError("A shape mismatch with x")

        A = A.to(x.dtype)

        # 1) Node-level associative retrieval.
        node_retrieved, node_energy = self.node_mem(x)

        # 2) Edge representations as pairwise differences, passed through memory.
        xi = x.unsqueeze(2).expand(B, N, N, D)
        xj = x.unsqueeze(1).expand(B, N, N, D)
        edge_repr = xi - xj
        edge_mem_out, edge_energy = self.edge_mem(edge_repr)

        # 3) Branch routing and adjacency merge.
        branch_w = self.gate(x, (A > 0).to(A.dtype))
        A_aug = self.merge(A, branch_w).clamp_min(0.0)

        # 4) Message aggregation with degree normalization.
        msg = self.msg_mlp(edge_mem_out)
        agg = torch.einsum("bij,bijd->bid", A_aug, msg)
        deg = A_aug.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        msg_norm = agg / deg

        # 5) Residual node update with layer norm.
        x_cat = torch.cat([x + node_retrieved, msg_norm], dim=-1)
        x_update = self.node_mlp(x_cat)
        x_next = self.norm(x + x_update)

        aux: AuxOut = {
            "branch_weights": branch_w,
            "hopfield_node_energy": node_energy.detach().clone(),
            "hopfield_edge_energy": edge_energy.detach().clone(),
        }

        return x_next, aux
|
|
@dataclass
class HopfieldDecisionGNNConfig:
    """Configuration for the complete Hopfield Decision GNN model."""

    dim: int
    layers: int = 3
    mem_slots_nodes: int = 64
    mem_slots_edges: int = 32
    branches: int = 4
    temperature: float = 0.7
    hard_eval: bool = True
|
class HopfieldDecisionGNN(nn.Module):
    """Complete Hopfield Decision GNN built from stacked layers.

    Implements a multi-layer graph neural network where each layer combines:
    - Hopfield associative memories for nodes and edges
    - Decision-tree-like branching for edge relation types
    - Adaptive message passing with memory retrieval

    The model learns to store and retrieve graph patterns while dynamically
    routing messages through different relation hypotheses based on context.

    Architecture:
    - Multiple HopfieldDecisionLayers stacked sequentially
    - Residual readout combining raw and processed representations
    """

    def __init__(self, cfg: HopfieldDecisionGNNConfig) -> None:
        super().__init__()
        self.cfg = cfg

        self.layers = nn.ModuleList([
            HopfieldDecisionLayer(
                dim=cfg.dim,
                mem_slots_nodes=cfg.mem_slots_nodes,
                mem_slots_edges=cfg.mem_slots_edges,
                branches=cfg.branches,
                temperature=cfg.temperature,
                hard_eval=cfg.hard_eval,
            )
            for _ in range(cfg.layers)
        ])

        self.readout = nn.Sequential(
            nn.Linear(cfg.dim, cfg.dim),
            nn.GELU(),
            nn.Linear(cfg.dim, cfg.dim),
        )
|
|
    def forward(self, x: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass through the complete stacked model.

        Processes graphs through multiple layers of memory-augmented message
        passing, collecting energy statistics and routing information across
        all layers for analysis and optimization.

        Mathematical flow:
        1. Initialize with the input node features
        2. For each layer l = 1..L:
           - Apply Hopfield memory retrieval
           - Compute decision branching weights
           - Perform message passing with routing
           - Update node representations
        3. Apply the final readout with a residual connection
        4. Aggregate auxiliary statistics across layers

        Args:
            x: Node features (B, N, D).
            A: Adjacency matrix (B, N, N).

        Returns:
            Tuple of:
            - y: Final node representations (B, N, D)
            - aux_all: Aggregated auxiliary information across layers
        """
        energies_node = []
        energies_edge = []
        last_br = None

        for i, layer in enumerate(self.layers):
            x, aux = layer(x, A)
            energies_node.append(aux["hopfield_node_energy"])
            energies_edge.append(aux["hopfield_edge_energy"])
            last_br = aux["branch_weights"]

            logger.debug(
                "Layer %d: node_energy=%.5f edge_energy=%.5f",
                i,
                float(aux["hopfield_node_energy"]),
                float(aux["hopfield_edge_energy"]),
            )

        # Residual readout.
        y = self.readout(x) + x

        aux_all: Dict[str, torch.Tensor] = {
            "node_energy_mean": torch.stack(energies_node).mean(),
            "edge_energy_mean": torch.stack(energies_edge).mean(),
        }
        if last_br is not None:
            aux_all["branch_weights_last"] = last_br

        return y, aux_all
|
|
def _make_batch(B: int, N: int, D: int, p_edge: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a random batch of graphs for testing."""
    torch.manual_seed(42)
    x = torch.randn(B, N, D)
    A = (torch.rand(B, N, N) < p_edge).float()
    # Remove self-loops.
    eye = torch.eye(N).unsqueeze(0)
    A = A * (1 - eye)
    return x, A


def test_shapes_and_aux() -> None:
    """Test basic functionality and output shapes."""
    cfg = HopfieldDecisionGNNConfig(dim=32, layers=2, branches=3)
    model = HopfieldDecisionGNN(cfg)
    x, A = _make_batch(B=2, N=5, D=32)
    y, aux = model(x, A)

    assert y.shape == x.shape, f"y shape {y.shape} != x {x.shape}"
    assert "node_energy_mean" in aux and "edge_energy_mean" in aux
    assert "branch_weights_last" in aux

    bw = aux["branch_weights_last"]
    assert bw.shape == (2, 5, 5, 3)

    # Branch weights must sum to 1 on existing edges.
    s = bw.sum(dim=-1)
    masked = (A == 0)
    assert torch.allclose(s[~masked], torch.ones_like(s[~masked]), atol=1e-5)
    print("[PASS] shapes_and_aux")
|
|
def test_gate_eval_hard_routing() -> None:
    """Test hard routing behavior during evaluation."""
    gate = DecisionGate(dim=16, branches=4, temperature=0.5, hard_eval=True)
    x, A = _make_batch(B=1, N=4, D=16)

    gate.eval()
    with torch.no_grad():
        w = gate(x, A)

    # Weights sum to 1 on edges and to 0 on non-edges.
    assert (w.sum(dim=-1) - (A > 0).float()).abs().max() < 1e-5

    # Hard routing: exactly one branch selected per edge.
    on_edges = (A[0] > 0)
    if on_edges.any():
        sub = w[0][on_edges]
        assert torch.allclose(sub.max(dim=-1).values, torch.ones_like(sub[..., 0]))
    print("[PASS] gate_eval_hard_routing")
|
|
def test_gradient_flow() -> None:
    """Test that gradients flow through the model."""
    cfg = HopfieldDecisionGNNConfig(dim=24, layers=3)
    model = HopfieldDecisionGNN(cfg)
    x, A = _make_batch(B=3, N=6, D=24)
    y, aux = model(x, A)

    loss = (y ** 2).mean() + aux["node_energy_mean"] * 0.01 + aux["edge_energy_mean"] * 0.01
    loss.backward()

    grads = [p.grad is not None and p.grad.abs().sum().item() > 0 for p in model.parameters()]
    assert any(grads), "No gradients found"
    print("[PASS] gradient_flow")


def test_batching_invariance() -> None:
    """Test that batching doesn't affect individual graph processing."""
    cfg = HopfieldDecisionGNNConfig(dim=12, layers=2)
    model = HopfieldDecisionGNN(cfg)
    x1, A1 = _make_batch(B=1, N=5, D=12)
    x2, A2 = _make_batch(B=1, N=5, D=12)

    # Process individually, then as a batch.
    y1, _ = model(x1, A1)
    y2, _ = model(x2, A2)
    y_cat, _ = model(torch.cat([x1, x2], dim=0), torch.cat([A1, A2], dim=0))

    assert torch.allclose(y1, y_cat[:1], atol=1e-5)
    assert torch.allclose(y2, y_cat[1:], atol=1e-5)
    print("[PASS] batching_invariance")
|
|
def test_shape_errors() -> None:
    """Test that appropriate errors are raised for invalid inputs."""
    cfg = HopfieldDecisionGNNConfig(dim=8, layers=1)
    model = HopfieldDecisionGNN(cfg)
    x, A = _make_batch(B=2, N=4, D=8)

    # Non-batched node features.
    try:
        _ = model(x[0], A)
        raise AssertionError("Expected GraphShapeError not raised")
    except GraphShapeError:
        pass

    # Non-batched adjacency.
    try:
        _ = model(x, A[0])
        raise AssertionError("Expected GraphShapeError not raised")
    except GraphShapeError:
        pass

    # Mismatched adjacency size.
    try:
        _ = model(x, A[:, :3, :3])
        raise AssertionError("Expected GraphShapeError not raised")
    except GraphShapeError:
        pass

    print("[PASS] shape_errors")
|
|
def test_hopfield_decision_graph() -> bool:
    """Comprehensive test of Hopfield Decision Graph functionality."""
    print("Testing Hopfield Decision Graph - Memory-Augmented Graph Neural Networks")
    print("=" * 85)

    cfg = HopfieldDecisionGNNConfig(
        dim=64,
        layers=4,
        mem_slots_nodes=32,
        mem_slots_edges=16,
        branches=3,
        temperature=0.8,
        hard_eval=True,
    )
    model = HopfieldDecisionGNN(cfg)

    print("Created Hopfield Decision GNN:")
    print(f"  - Feature dimension: {cfg.dim}")
    print(f"  - Number of layers: {cfg.layers}")
    print(f"  - Node memory slots: {cfg.mem_slots_nodes}")
    print(f"  - Edge memory slots: {cfg.mem_slots_edges}")
    print(f"  - Decision branches: {cfg.branches}")

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  - Total parameters: {total_params:,}")

    batch_size, num_nodes = 8, 12
    x, A = _make_batch(batch_size, num_nodes, cfg.dim, p_edge=0.3)

    print("\nTesting with graphs:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Nodes per graph: {num_nodes}")
    print("  - Edge density: ~30%")
    print(f"  - Total edges: {A.sum().item():.0f}")

    print("\nExecuting forward pass...")
    y, aux = model(x, A)

    print("Forward pass results:")
    print(f"  - Output shape: {y.shape}")
    print(f"  - Node energy: {aux['node_energy_mean']:.4f}")
    print(f"  - Edge energy: {aux['edge_energy_mean']:.4f}")

    if "branch_weights_last" in aux:
        branch_weights = aux["branch_weights_last"]
        print("\nDecision branching analysis:")
        print(f"  - Branch weights shape: {branch_weights.shape}")

        # Routing entropy per edge, normalized by the maximum log(K).
        branch_entropy = -(branch_weights * torch.log(branch_weights + 1e-8)).sum(dim=-1)
        avg_entropy = branch_entropy[A > 0].mean().item()
        max_entropy = float(np.log(cfg.branches))
        branching_diversity = avg_entropy / max_entropy

        print(f"  - Average branching entropy: {avg_entropy:.3f}")
        print(f"  - Branching diversity: {branching_diversity:.1%}")

        most_common_branches = branch_weights.argmax(dim=-1)
        for b in range(cfg.branches):
            count = (most_common_branches == b).sum().item()
            total_edges = (A > 0).sum().item()
            pct = count / max(total_edges, 1) * 100
            print(f"  - Branch {b}: {count} edges ({pct:.1f}%)")

    print("\nTesting memory components...")

    test_memory = HopfieldMemory(cfg.dim, mem_slots=16)
    test_input = torch.randn(4, cfg.dim)
    retrieved, energy = test_memory(test_input)

    print(f"  - Memory retrieval shape: {retrieved.shape}")
    print(f"  - Memory energy: {energy:.4f}")

    test_gate = DecisionGate(cfg.dim, branches=cfg.branches)
    test_nodes = torch.randn(2, 6, cfg.dim)
    test_adj = torch.randint(0, 2, (2, 6, 6)).float()
    gate_weights = test_gate(test_nodes, test_adj)

    print(f"  - Gate output shape: {gate_weights.shape}")
    print(f"  - Gate simplex check: {torch.allclose(gate_weights.sum(-1), test_adj, atol=1e-5)}")

    print("\nTesting structural adaptivity...")

    # Fully connected graph (no self-loops).
    dense_A = torch.ones(1, num_nodes, num_nodes) - torch.eye(num_nodes).unsqueeze(0)
    dense_x = torch.randn(1, num_nodes, cfg.dim)
    dense_y, dense_aux = model(dense_x, dense_A)

    # Sparse 3-cycle on the first three nodes.
    sparse_A = torch.zeros(1, num_nodes, num_nodes)
    sparse_A[0, 0, 1] = sparse_A[0, 1, 2] = sparse_A[0, 2, 0] = 1
    sparse_x = torch.randn(1, num_nodes, cfg.dim)
    sparse_y, sparse_aux = model(sparse_x, sparse_A)

    print(f"  - Dense graph node energy: {dense_aux['node_energy_mean']:.4f}")
    print(f"  - Sparse graph node energy: {sparse_aux['node_energy_mean']:.4f}")
    print(f"  - Dense graph edge energy: {dense_aux['edge_energy_mean']:.4f}")
    print(f"  - Sparse graph edge energy: {sparse_aux['edge_energy_mean']:.4f}")

    print("\nTesting evaluation mode...")
    model.eval()
    with torch.no_grad():
        eval_y, eval_aux = model(x[:2], A[:2])

    if "branch_weights_last" in eval_aux:
        eval_weights = eval_aux["branch_weights_last"]
        # In eval mode with hard_eval=True, routing is one-hot on every edge.
        max_vals = eval_weights.max(dim=-1)[0]
        edges_mask = (A[:2] > 0)
        hard_routing_check = torch.allclose(max_vals[edges_mask], torch.ones_like(max_vals[edges_mask]))
        print(f"  - Hard routing active: {hard_routing_check}")

    model.train()

    print("\nHopfield Decision Graph test completed!")
    print("✓ Dual memory architecture (nodes + edges)")
    print("✓ Decision-tree edge routing with soft branching")
    print("✓ Energy-based associative memory retrieval")
    print("✓ Hard/soft routing modes for training/evaluation")
    print("✓ Memory-augmented graph message passing")
    print("✓ Adaptive edge semantics based on node context")

    return True
|
|
def memory_pattern_demo() -> None:
    """Demonstrate memory pattern storage and retrieval."""
    print("\n" + "=" * 60)
    print(" MEMORY PATTERN DEMONSTRATION")
    print("=" * 60)

    memory = HopfieldMemory(dim=8, mem_slots=4)

    # Overwrite the learned slots with hand-crafted prototype patterns.
    with torch.no_grad():
        memory.keys[0] = torch.tensor([1, 0, 1, 0, 1, 0, 1, 0], dtype=torch.float32)
        memory.keys[1] = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.float32)
        memory.keys[2] = torch.tensor([1, 1, 0, 0, 1, 1, 0, 0], dtype=torch.float32)
        memory.keys[3] = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.float32)
        memory.vals.copy_(memory.keys)

    print("Testing pattern completion:")
    test_patterns = [
        torch.tensor([1, 0, 1, 0, 0.5, 0, 0.8, 0], dtype=torch.float32),
        torch.tensor([0, 1, 0, 1, 0.2, 1, 0, 0.9], dtype=torch.float32),
        torch.tensor([1, 1, 0, 0, 0.7, 0.8, 0, 0], dtype=torch.float32),
    ]

    for i, noisy_pattern in enumerate(test_patterns):
        retrieved, energy = memory(noisy_pattern.unsqueeze(0))
        retrieved = retrieved.squeeze(0)

        print(f"\n Test {i + 1}:")
        print(f"   Input:     {noisy_pattern.numpy()}")
        print(f"   Retrieved: {retrieved.detach().numpy()}")
        print(f"   Energy:    {energy.item():.3f}")

    print("\n Memory demonstrates pattern completion and associative recall!")
    print(" Noisy inputs are attracted toward the stored prototype patterns")
    print(" (up to the model's learned query/output projections)")
|
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    test_shapes_and_aux()
    test_gate_eval_hard_routing()
    test_gradient_flow()
    test_batching_invariance()
    test_shape_errors()
    test_hopfield_decision_graph()
    memory_pattern_demo()
    print("\nAll tests passed")