###########################################################################################################################################
#||- - - |8.19.2025| - - - || HOPFIELD DECISION GRAPH || - - - | 1990two | - - -||#
###########################################################################################################################################
"""
Mathematical Foundation & Conceptual Documentation
-------------------------------------------------
CORE PRINCIPLE:
Combines Graph Neural Networks with Hopfield associative memories and decision-tree-like
edge branching to create networks where both nodes and edges have memory capabilities.
Each edge can dynamically route through multiple relation hypotheses, enabling
adaptive graph reasoning with memory-augmented message passing.
MATHEMATICAL FOUNDATION:
=======================
1. HOPFIELD MEMORY MECHANICS:
Energy Function: E = -½ ∑ᵢⱼ wᵢⱼ sᵢ sⱼ + ∑ᵢ θᵢ sᵢ
Where:
- wᵢⱼ: connection weights between memory units
- sᵢ, sⱼ: activation states of units i, j
- θᵢ: threshold for unit i
- E: system energy (minimized during retrieval)
2. ASSOCIATIVE MEMORY RETRIEVAL:
Content Addressing: aₜ = softmax(βₜ · K(q, M))
Where:
- q: query vector
- M: memory matrix [memory_slots, memory_dim]
- K(q,M): similarity function (typically cosine similarity)
- βₜ: temperature parameter (attention sharpness)
- aₜ: attention weights over memory slots
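A minimal PyTorch sketch of this retrieval step (illustrative only; q, M,
and beta are assumed inputs, and cosine similarity stands in for K):

    import torch
    import torch.nn.functional as F
    q = torch.randn(16)                                     # query vector
    M = torch.randn(8, 16)                                  # 8 memory slots, dim 16
    beta = 2.0                                              # attention sharpness
    sims = F.cosine_similarity(q.unsqueeze(0), M, dim=-1)   # K(q, M): (8,)
    a = F.softmax(beta * sims, dim=-1)                      # weights over slots
    read = a @ M                                            # retrieved pattern: (16,)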
3. DECISION GATE BRANCHING:
Branch Weights: w = softmax(EdgeScorer(concat(xᵢ, xⱼ))/τ)
Where:
- xᵢ, xⱼ: node features for edge (i,j)
- EdgeScorer: neural network mapping edge features to K branch logits
- τ: temperature parameter
- w: simplex over K relation hypotheses
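A matching sketch (illustrative; a single Linear stands in for EdgeScorer):

    import torch
    import torch.nn.functional as F
    xi, xj = torch.randn(16), torch.randn(16)       # endpoint node features
    scorer = torch.nn.Linear(2 * 16, 4)             # concat (2D) -> K=4 branch logits
    logits = scorer(torch.cat([xi, xj], dim=-1))    # (4,)
    w = F.softmax(logits / 0.7, dim=-1)             # simplex over K branches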
4. MESSAGE PASSING WITH MEMORY:
Node Update: hᵢ⁽ˡ⁺¹⁾ = NodeMemory(hᵢ⁽ˡ⁾ + ∑ⱼ∈N(i) Aᵢⱼ · EdgeMemory(hᵢ⁽ˡ⁾, hⱼ⁽ˡ⁾))
Where:
- hᵢ⁽ˡ⁾: node representation at layer l
- N(i): neighbors of node i
- Aᵢⱼ: attention-weighted adjacency (after decision branching)
- NodeMemory, EdgeMemory: Hopfield memory modules
5. BARYCENTRIC EDGE MERGING:
A'ᵢⱼ = (Aᵢⱼ > 0) · mean_k(wᵢⱼₖ)
Where:
- A'ᵢⱼ: merged edge weight
- wᵢⱼₖ: weight for branch k on edge (i,j)
- Keeps edge weights in convex hull of branch hypotheses
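Worked example: with K = 3 branch weights wᵢⱼ = (0.2, 0.5, 0.3) on an
existing edge (Aᵢⱼ = 1), A'ᵢⱼ = (0.2 + 0.5 + 0.3)/3 ≈ 0.333; on a
non-edge (Aᵢⱼ = 0) the merged weight stays 0.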
6. ENERGY MINIMIZATION:
Retrieval dynamics: ḣ = -∂E/∂h, where E is the Hopfield energy
Memory retrieval follows gradient descent on the energy landscape, settling
into stored attractor patterns.
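A toy autograd sketch of this descent (illustrative; W is an assumed
symmetric weight matrix, and real Hopfield dynamics bound the state):

    import torch
    W = torch.randn(8, 8)
    W = 0.5 * (W + W.T)                             # symmetric weights
    h = torch.randn(8, requires_grad=True)
    for _ in range(10):
        E = -0.5 * h @ W @ h                        # E = -1/2 h^T W h
        g, = torch.autograd.grad(E, h)
        h = (h - 0.1 * g).detach().requires_grad_() # h <- h - lr * dE/dh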
CONCEPTUAL REASONING:
====================
WHY HOPFIELD + GRAPHS + DECISION BRANCHING?
- Standard GNNs assume fixed edge semantics
- Real-world graphs have ambiguous, multi-faceted relationships
- Hopfield memories provide content-addressable associative recall
- Decision branching enables soft routing through relation types
- Memory-augmented edges learn context-dependent message passing
KEY INNOVATIONS:
1. **Dual Memory Architecture**: Both nodes and edges have associative memories
2. **Decision-Tree Edge Routing**: Soft branching through K relation hypotheses
3. **Hard/Soft Routing Modes**: Soft routing during training, deterministic one-hot routing during evaluation
4. **Energy-Based Retrieval**: Hopfield dynamics for memory access
5. **Hierarchical Message Passing**: Memory → branching → aggregation
APPLICATIONS:
- Knowledge graph reasoning with uncertain relations
- Social network analysis with multi-type interactions
- Molecular property prediction with bond ambiguity
- Recommendation systems with multi-faceted user-item relations
- Program analysis with context-dependent variable relationships
COMPLEXITY ANALYSIS:
- Node Memory: O(N · D · S) where N=nodes, D=features, S=memory_slots
- Edge Memory: O(E · D · S) where E=edges
- Decision Branching: O(E · K) where K=branch_count
- Message Passing: O(E · D + N · D²) per layer for sparse implementations
- Memory: O((N+E) · S · D) for stored patterns
- Note: this reference implementation materializes dense (N, N) edge tensors,
  so its practical cost scales as O(N² · D) per layer
BIOLOGICAL INSPIRATION:
- Hippocampal pattern completion and separation
- Cortical associative memory networks
- Synaptic plasticity and connection strength modulation
- Neural circuit motifs with context-dependent routing
- Memory consolidation through repeated activation patterns
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from typing import Dict, Optional, Protocol, Tuple, TypedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Configure logging
logger = logging.getLogger(__name__)
if not logger.handlers:
_h = logging.StreamHandler()
_f = logging.Formatter("%(asctime)s | %(name)s | %(levelname)s | %(message)s")
_h.setFormatter(_f)
logger.addHandler(_h)
logger.setLevel(logging.INFO)
#||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||#
class GraphShapeError(ValueError):
"""Raised when provided tensors do not match expected graph shapes."""
class RoutingError(RuntimeError):
"""Raised when routing/branching fails due to numerical or config issues."""
class AuxOut(TypedDict):
"""Auxiliary outputs returned by the model."""
branch_weights: torch.Tensor # (B, N, N, K) simplex on branches
hopfield_node_energy: torch.Tensor # scalar-like
hopfield_edge_energy: torch.Tensor # scalar-like
###########################################################################################################################################
#################################################- - - HELPERS + PROTOCOLS - - -#######################################################
class MergeStrategy(Protocol):
"""Protocol for merging K branch-specific edge weights into an adjacency."""
def __call__(self, base_adj: torch.Tensor, branch_weights: torch.Tensor) -> torch.Tensor:
"""Merge branch weights into an augmented adjacency.
Args:
base_adj: (B, N, N) original adjacency (0/1 or weighted).
branch_weights: (B, N, N, K) simplex per edge over K branches.
Returns:
(B, N, N) merged adjacency.
"""
def barycentric_merge(base_adj: torch.Tensor, branch_weights: torch.Tensor) -> torch.Tensor:
"""Barycentric merge of branches into a single weighted adjacency.
This keeps edge weights in a convex hull of branch hypotheses and the base
adjacency. It's simple, stable, and differentiable.
Mathematical Details:
- Takes mean of branch weights: w̄ = (1/K) Σₖ wₖ
- Applies to existing edges: A'ᵢⱼ = (Aᵢⱼ > 0) · w̄ᵢⱼ
- Preserves graph structure while allowing soft edge weights
Args:
base_adj: (B, N, N) - Original adjacency matrix
branch_weights: (B, N, N, K) - Branch weight simplex
Returns:
(B, N, N) - Merged adjacency with barycentric edge weights
"""
bw = branch_weights.mean(dim=-1) # (B, N, N)
return (base_adj > 0).to(base_adj.dtype) * bw
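# A minimal alternative MergeStrategy sketch (hypothetical helper, not used by
# default): weight each existing edge by its single strongest branch hypothesis
# instead of the branch mean. It satisfies the same (B, N, N) contract as
# `barycentric_merge` and can be passed as `merge=` to HopfieldDecisionLayer.
def max_branch_merge(base_adj: torch.Tensor, branch_weights: torch.Tensor) -> torch.Tensor:
    """Winner-take-most merge: keep the max branch weight per edge (illustrative)."""
    bw = branch_weights.max(dim=-1).values  # (B, N, N) strongest hypothesis per edge
    return (base_adj > 0).to(base_adj.dtype) * bw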
###########################################################################################################################################
###################################################- - - HOPFIELD MEMORY - - -##########################################################
class HopfieldMemory(nn.Module):
"""Hopfield associative memory with content-based retrieval.
Implements a modern Hopfield network that stores patterns in key-value
memory slots and retrieves them via content-based attention. Uses
temperature-controlled softmax for retrieval sharpness.
Mathematical Framework:
- Keys: K ∈ ℝˢˣᴰ (memory slots × feature dimension)
- Values: V ∈ ℝˢˣᴰ (stored pattern values)
- Query: q ∈ ℝᴰ (input pattern for retrieval)
- Attention: α = softmax(qᵀK / (√D · key_scale))
- Output: o = α^T V (weighted combination of stored patterns)
The "energy" proxy measures retrieval sharpness (low entropy = high energy).
"""
def __init__(self, dim: int, mem_slots: int = 64, key_scale: float = 1.0) -> None:
super().__init__()
self.dim = dim
self.mem_slots = mem_slots
self.key_scale = float(key_scale)
# Initialize memory with small random patterns
self.keys = nn.Parameter(torch.randn(mem_slots, dim) * (1.0 / np.sqrt(dim)))
self.vals = nn.Parameter(torch.randn(mem_slots, dim) * (1.0 / np.sqrt(dim)))
# Learned projections for query and output
self.proj_q = nn.Linear(dim, dim, bias=False)
self.proj_o = nn.Linear(dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Retrieve patterns from associative memory via content addressing.
Implements the core Hopfield retrieval mechanism:
1. Project input to query space
2. Compute content-based attention over memory keys
3. Retrieve weighted combination of stored values
4. Project output and compute energy proxy
Mathematical Details:
- Query projection: q = W_q · x
- Similarity: sim = q^T K_i for each memory slot i
- Attention: α_i = exp(sim_i/τ) / Σⱼ exp(sim_j/τ)
- Retrieval: r = Σᵢ α_i · V_i
- Output: o = W_o · r
Args:
x: Input tensor (..., D) - query patterns
Returns:
Tuple of:
- out: Retrieved patterns (..., D)
- energy: Scalar energy proxy (higher = more focused retrieval)
"""
# Project input to query space
q = self.proj_q(x) # (..., D)
# Compute attention weights via content-based addressing
scale = np.sqrt(self.dim) * max(self.key_scale, 1e-6)
attn = F.softmax((q @ self.keys.T) / scale, dim=-1) # (..., S)
# Retrieve weighted combination of stored values
retrieved = attn @ self.vals # (..., D)
out = self.proj_o(retrieved)
# Compute energy proxy: negative entropy (higher = more focused)
p = attn.clamp_min(1e-9)
entropy = -(p * p.log()).sum(dim=-1).mean()
energy = -entropy # High focus = low entropy = high energy
return out, energy
###########################################################################################################################################
###################################################- - - DECISION GATE - - -############################################################
class DecisionGate(nn.Module):
"""Learnable branching gate for multi-hypothesis edge routing.
Implements soft decision tree routing where each edge can branch through
K different relation hypotheses. Uses pairwise node features to predict
branch probabilities, enabling context-dependent edge semantics.
Mathematical Framework:
- Edge features: eᵢⱼ = concat(xᵢ, xⱼ) ∈ ℝ²ᴰ
- Branch logits: lᵢⱼ = MLP(eᵢⱼ) ∈ ℝᴷ
- Branch weights: wᵢⱼ = softmax(lᵢⱼ/τ) ∈ Δᴷ⁻¹
Supports both soft routing (training) and hard routing (evaluation)
for computational efficiency and interpretability.
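    Example (shape sketch):
        >>> gate = DecisionGate(dim=16, branches=4)
        >>> x = torch.randn(2, 5, 16)                    # (B, N, D)
        >>> mask = (torch.rand(2, 5, 5) < 0.5).float()   # (B, N, N)
        >>> gate(x, mask).shape
        torch.Size([2, 5, 5, 4])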
"""
def __init__(self, dim: int, branches: int = 4, temperature: float = 0.7, hard_eval: bool = True) -> None:
super().__init__()
if branches < 1:
raise ValueError("branches must be >= 1")
self.dim = dim
self.K = branches
self.temperature = float(temperature)
self.hard_eval = bool(hard_eval)
# Edge scoring network: maps concatenated node features to branch logits
self.edge_scorer = nn.Sequential(
nn.Linear(2 * dim, dim),
nn.GELU(),
nn.Linear(dim, self.K),
)
@staticmethod
def _pairwise_concat(x: torch.Tensor) -> torch.Tensor:
"""Create pairwise concatenations for all edge combinations.
Generates edge feature representations by concatenating all pairs
of node features, creating a complete edge feature tensor.
Mathematical Details:
- For nodes X = [x₁, x₂, ..., xₙ] ∈ ℝᴺˣᴰ
- Create eᵢⱼ = [xᵢ; xⱼ] ∈ ℝ²ᴰ for all pairs (i,j)
- Result: E ∈ ℝᴺˣᴺˣ²ᴰ
Args:
x: Node features (B, N, D)
Returns:
Edge features (B, N, N, 2D) - concatenated pairwise features
"""
B, N, D = x.shape
xi = x.unsqueeze(2).expand(B, N, N, D) # Source node features
xj = x.unsqueeze(1).expand(B, N, N, D) # Target node features
return torch.cat([xi, xj], dim=-1)
def forward(self, x: torch.Tensor, mask_adj: torch.Tensor) -> torch.Tensor:
"""Compute branch routing probabilities for each edge.
Determines how each edge should route through K relation hypotheses
based on the features of its endpoint nodes. Masked edges receive
zero routing weights.
Mathematical Process:
1. Create pairwise edge features eᵢⱼ = [xᵢ; xⱼ]
2. Compute branch logits lᵢⱼ = EdgeScorer(eᵢⱼ)
3. Apply temperature and softmax: wᵢⱼ = softmax(lᵢⱼ/τ)
4. Mask non-edges and renormalize
Args:
x: Node features (B, N, D)
mask_adj: Edge mask (B, N, N) - 1 for valid edges, 0 for non-edges
Returns:
Branch weights (B, N, N, K) - simplex per edge over K branches
"""
if x.dim() != 3:
raise GraphShapeError("x must be (B,N,D)")
if mask_adj.dim() != 3:
raise GraphShapeError("mask_adj must be (B,N,N)")
B, N, D = x.shape
        if mask_adj.shape != (B, N, N):
            raise GraphShapeError("mask_adj shape mismatch")
# Create pairwise edge features
edge_feats = self._pairwise_concat(x) # (B, N, N, 2D)
# Compute branch logits for each edge
logits = self.edge_scorer(edge_feats) # (B, N, N, K)
# Apply temperature and routing mode
temp = max(self.temperature, 1e-5)
if self.training:
# Soft routing during training
weights = F.softmax(logits / temp, dim=-1)
else:
# Hard or soft routing during evaluation
w = F.softmax(logits / temp, dim=-1)
if self.hard_eval:
# Hard routing: one-hot at maximum branch
idx = w.argmax(dim=-1, keepdim=True)
hard = torch.zeros_like(w).scatter_(-1, idx, 1.0)
weights = hard
else:
weights = w
# Mask out non-edges while preserving simplex property
weights = weights * mask_adj.unsqueeze(-1)
sums = weights.sum(dim=-1, keepdim=True)
weights = torch.where(sums > 0, weights / sums.clamp_min(1e-9), weights)
return weights
###########################################################################################################################################
###############################################- - - HOPFIELD DECISION LAYER - - -#####################################################
class HopfieldDecisionLayer(nn.Module):
"""Graph neural network layer with Hopfield memories and decision branching.
Integrates three key components:
1. Node-level Hopfield memory for pattern completion
2. Edge-level Hopfield memory for relation-specific message encoding
3. Decision gate for soft routing through relation hypotheses
This creates a powerful message-passing framework where both nodes and
edges have associative memory capabilities, and edge semantics can
dynamically adapt based on context.
Mathematical Flow:
1. Node memory retrieval: h'ᵢ = HopfieldNode(hᵢ)
2. Edge message encoding: mᵢⱼ = HopfieldEdge(hᵢ - hⱼ)
3. Decision branching: wᵢⱼ = DecisionGate(hᵢ, hⱼ)
    4. Message aggregation: m̄ᵢ = (1/deg(i)) Σⱼ A'ᵢⱼ · mᵢⱼ
5. Node update: hᵢ⁽ˡ⁺¹⁾ = LayerNorm(hᵢ + MLP([h'ᵢ; m̄ᵢ]))
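    Example (shape sketch):
        >>> layer = HopfieldDecisionLayer(dim=16)
        >>> x = torch.randn(2, 5, 16)
        >>> A = (torch.rand(2, 5, 5) < 0.5).float()
        >>> x_next, aux = layer(x, A)
        >>> (x_next.shape, aux["branch_weights"].shape)
        (torch.Size([2, 5, 16]), torch.Size([2, 5, 5, 4]))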
"""
def __init__(
self,
dim: int,
mem_slots_nodes: int = 64,
mem_slots_edges: int = 32,
branches: int = 4,
temperature: float = 0.7,
hard_eval: bool = True,
merge: Optional[MergeStrategy] = None,
) -> None:
super().__init__()
# Core components
self.node_mem = HopfieldMemory(dim, mem_slots_nodes)
self.edge_mem = HopfieldMemory(dim, mem_slots_edges)
self.gate = DecisionGate(dim, branches, temperature, hard_eval)
self.merge = merge or barycentric_merge
# Message processing networks
self.msg_mlp = nn.Sequential(
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
# Node update network
self.node_mlp = nn.Sequential(
nn.Linear(2 * dim, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
self.norm = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, AuxOut]:
"""Execute one layer of memory-augmented graph message passing.
Implements the complete forward pass combining Hopfield memory
retrieval, decision branching, and message aggregation.
Mathematical Algorithm:
1. Retrieve node patterns: h'ᵢ = NodeMemory(hᵢ)
2. Encode edge messages: mᵢⱼ = EdgeMemory(hᵢ - hⱼ)
3. Compute branch routing: wᵢⱼ = DecisionGate(hᵢ, hⱼ)
4. Merge adjacency: A' = MergeStrategy(A, w)
5. Aggregate messages: m̄ᵢ = Σⱼ A'ᵢⱼ · MLP(mᵢⱼ) / deg(i)
6. Update nodes: hᵢ⁽ˡ⁺¹⁾ = LayerNorm(hᵢ + MLP([h'ᵢ; m̄ᵢ]))
Args:
x: Node features (B, N, D)
A: Adjacency matrix (B, N, N) - 0/1 or weighted
Returns:
Tuple of:
- x_next: Updated node features (B, N, D)
- aux: Auxiliary outputs (branch weights, energy values)
"""
if x.dim() != 3:
raise GraphShapeError("x must be (B,N,D)")
if A.dim() != 3:
raise GraphShapeError("A must be (B,N,N)")
B, N, D = x.shape
if A.shape != (B, N, N):
raise GraphShapeError("A shape mismatch with x")
A = A.to(x.dtype)
# Step 1: Node-level Hopfield memory retrieval
node_retrieved, node_energy = self.node_mem(x) # (B,N,D), scalar
# Step 2: Edge-level message encoding with Hopfield memory
# Use anti-symmetric edge representation (hᵢ - hⱼ) for directional messages
xi = x.unsqueeze(2).expand(B, N, N, D)
xj = x.unsqueeze(1).expand(B, N, N, D)
edge_repr = xi - xj # Anti-symmetric: eᵢⱼ = -eⱼᵢ
edge_mem_out, edge_energy = self.edge_mem(edge_repr) # (B,N,N,D), scalar
# Step 3: Decision branching for edge routing
branch_w = self.gate(x, (A > 0).to(A.dtype)) # (B,N,N,K)
A_aug = self.merge(A, branch_w).clamp_min(0.0)
# Step 4: Message passing with degree normalization
msg = self.msg_mlp(edge_mem_out) # (B,N,N,D)
agg = torch.einsum("bij,bijd->bid", A_aug, msg) # (B,N,D)
# Degree normalization for stability across graph sizes
deg = A_aug.sum(dim=-1, keepdim=True).clamp_min(1e-9)
msg_norm = agg / deg
# Step 5: Node update with residual connection
x_cat = torch.cat([x + node_retrieved, msg_norm], dim=-1)
x_update = self.node_mlp(x_cat)
x_next = self.norm(x + x_update)
# Package auxiliary outputs
aux: AuxOut = {
"branch_weights": branch_w,
"hopfield_node_energy": node_energy.detach().clone(),
"hopfield_edge_energy": edge_energy.detach().clone(),
}
return x_next, aux
###########################################################################################################################################
###############################################- - - HOPFIELD DECISION GNN - - -#######################################################
@dataclass
class HopfieldDecisionGNNConfig:
"""Configuration for the complete Hopfield Decision GNN model."""
dim: int # Feature dimension
layers: int = 3 # Number of GNN layers
mem_slots_nodes: int = 64 # Node memory capacity
mem_slots_edges: int = 32 # Edge memory capacity
branches: int = 4 # Decision branches per edge
temperature: float = 0.7 # Branching temperature
hard_eval: bool = True # Use hard routing in eval mode
class HopfieldDecisionGNN(nn.Module):
"""Complete Hopfield Decision GNN with stacked layers and global adaptation.
Implements a multi-layer graph neural network where each layer combines:
- Hopfield associative memories for nodes and edges
- Decision-tree-like branching for edge relation types
- Adaptive message passing with memory retrieval
The model learns to store and retrieve graph patterns while dynamically
routing messages through different relation hypotheses based on context.
Architecture:
- Multiple HopfieldDecisionLayers stacked sequentially
- Global adaptation mechanism for model-wide learning
- Residual readout combining raw and processed representations
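    Example (shape sketch):
        >>> cfg = HopfieldDecisionGNNConfig(dim=16, layers=2)
        >>> model = HopfieldDecisionGNN(cfg)
        >>> y, aux = model(torch.randn(2, 5, 16), (torch.rand(2, 5, 5) < 0.5).float())
        >>> y.shape
        torch.Size([2, 5, 16])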
"""
def __init__(self, cfg: HopfieldDecisionGNNConfig) -> None:
super().__init__()
self.cfg = cfg
# Stack of Hopfield decision layers
self.layers = nn.ModuleList([
HopfieldDecisionLayer(
dim=cfg.dim,
mem_slots_nodes=cfg.mem_slots_nodes,
mem_slots_edges=cfg.mem_slots_edges,
branches=cfg.branches,
temperature=cfg.temperature,
hard_eval=cfg.hard_eval,
)
for _ in range(cfg.layers)
])
# Global adaptation and output processing
self.readout = nn.Sequential(
nn.Linear(cfg.dim, cfg.dim),
nn.GELU(),
nn.Linear(cfg.dim, cfg.dim),
)
def forward(self, x: torch.Tensor, A: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Forward pass through the complete stacked model.
Processes graphs through multiple layers of memory-augmented message
passing, collecting energy statistics and routing information across
all layers for analysis and optimization.
Mathematical Flow:
1. Initialize with input node features
2. For each layer l = 1..L:
- Apply Hopfield memory retrieval
- Compute decision branching weights
- Perform message passing with routing
- Update node representations
3. Apply final readout with residual connection
4. Aggregate auxiliary statistics across layers
Args:
x: Node features (B, N, D)
A: Adjacency matrix (B, N, N)
Returns:
Tuple of:
- y: Final node representations (B, N, D)
- aux_all: Aggregated auxiliary information across layers
"""
energies_node = []
energies_edge = []
last_br = None
# Process through all layers
for i, layer in enumerate(self.layers):
x, aux = layer(x, A)
# Collect energy statistics
energies_node.append(aux["hopfield_node_energy"])
energies_edge.append(aux["hopfield_edge_energy"])
last_br = aux["branch_weights"]
# Debug logging
logger.debug(
"Layer %d: node_energy=%.5f edge_energy=%.5f",
i,
float(aux["hopfield_node_energy"]),
float(aux["hopfield_edge_energy"]),
)
# Final readout with residual connection
y = self.readout(x) + x
# Aggregate auxiliary outputs
aux_all: Dict[str, torch.Tensor] = {
"node_energy_mean": torch.stack(energies_node).mean(),
"edge_energy_mean": torch.stack(energies_edge).mean(),
}
if last_br is not None:
aux_all["branch_weights_last"] = last_br
return y, aux_all
###########################################################################################################################################
###############################################- - - HOPFIELD DEMO + TEST - - -########################################################
def _make_batch(B: int, N: int, D: int, p_edge: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate random batch of graphs for testing."""
torch.manual_seed(42)
x = torch.randn(B, N, D)
A = (torch.rand(B, N, N) < p_edge).float()
    # Remove self-loops
eye = torch.eye(N).unsqueeze(0)
A = A * (1 - eye)
return x, A
def test_shapes_and_aux() -> None:
"""Test basic functionality and output shapes."""
cfg = HopfieldDecisionGNNConfig(dim=32, layers=2, branches=3)
model = HopfieldDecisionGNN(cfg)
x, A = _make_batch(B=2, N=5, D=32)
y, aux = model(x, A)
assert y.shape == x.shape, f"y shape {y.shape} != x {x.shape}"
assert "node_energy_mean" in aux and "edge_energy_mean" in aux
assert "branch_weights_last" in aux
bw = aux["branch_weights_last"]
assert bw.shape == (2, 5, 5, 3)
# Check simplex property on existing edges
s = bw.sum(dim=-1)
masked = (A == 0)
assert torch.allclose(s[~masked], torch.ones_like(s[~masked]), atol=1e-5)
print("[PASS] shapes_and_aux")
def test_gate_eval_hard_routing() -> None:
"""Test hard routing behavior during evaluation."""
    gate = DecisionGate(dim=16, branches=4, temperature=0.5, hard_eval=True)
x, A = _make_batch(B=1, N=4, D=16)
gate.eval()
with torch.no_grad():
w = gate(x, A)
assert (w.sum(dim=-1) - (A > 0).float()).abs().max() < 1e-5
# Check one-hot property on edges
on_edges = (A[0] > 0)
if on_edges.any():
sub = w[0][on_edges]
assert torch.allclose(sub.max(dim=-1).values, torch.ones_like(sub[..., 0]))
print("[PASS] gate_eval_hard_routing")
def test_gradient_flow() -> None:
"""Test that gradients flow through the model."""
cfg = HopfieldDecisionGNNConfig(dim=24, layers=3)
model = HopfieldDecisionGNN(cfg)
x, A = _make_batch(B=3, N=6, D=24)
y, aux = model(x, A)
    # Loss on the output; the aux energies are detached in AuxOut, so those
    # terms are constants and gradients flow only through y
    loss = (y ** 2).mean() + aux["node_energy_mean"] * 0.01 + aux["edge_energy_mean"] * 0.01
loss.backward()
# Check that some parameters received gradients
grads = [p.grad is not None and p.grad.abs().sum().item() > 0 for p in model.parameters()]
assert any(grads), "No gradients found"
print("[PASS] gradient_flow")
def test_batching_invariance() -> None:
"""Test that batching doesn't affect individual graph processing."""
cfg = HopfieldDecisionGNNConfig(dim=12, layers=2)
model = HopfieldDecisionGNN(cfg)
    # _make_batch reseeds identically on every call, so build one batch of two
    # distinct graphs and split it instead
    x, A = _make_batch(B=2, N=5, D=12)
    x1, A1 = x[:1], A[:1]
    x2, A2 = x[1:], A[1:]
# Process individually
y1, _ = model(x1, A1)
y2, _ = model(x2, A2)
# Process as batch
y_cat, _ = model(torch.cat([x1, x2], dim=0), torch.cat([A1, A2], dim=0))
assert torch.allclose(y1, y_cat[:1], atol=1e-5)
assert torch.allclose(y2, y_cat[1:], atol=1e-5)
print("[PASS] batching_invariance")
def test_shape_errors() -> None:
"""Test that appropriate errors are raised for invalid inputs."""
cfg = HopfieldDecisionGNNConfig(dim=8, layers=1)
model = HopfieldDecisionGNN(cfg)
x, A = _make_batch(B=2, N=4, D=8)
# Test wrong rank for x
try:
_ = model(x[0], A)
raise AssertionError("Expected GraphShapeError not raised")
except GraphShapeError:
pass
# Test wrong rank for A
try:
_ = model(x, A[0])
raise AssertionError("Expected GraphShapeError not raised")
except GraphShapeError:
pass
# Test mismatched dimensions
try:
_ = model(x, A[:, :3, :3])
raise AssertionError("Expected GraphShapeError not raised")
except GraphShapeError:
pass
print("[PASS] shape_errors")
def test_hopfield_decision_graph() -> bool:
"""Comprehensive test of Hopfield Decision Graph functionality."""
print("Testing Hopfield Decision Graph - Memory-Augmented Graph Neural Networks")
print("=" * 85)
# Create Hopfield Decision GNN
cfg = HopfieldDecisionGNNConfig(
dim=64,
layers=4,
mem_slots_nodes=32,
mem_slots_edges=16,
branches=3,
temperature=0.8,
hard_eval=True
)
model = HopfieldDecisionGNN(cfg)
print(f"Created Hopfield Decision GNN:")
print(f" - Feature dimension: {cfg.dim}")
print(f" - Number of layers: {cfg.layers}")
print(f" - Node memory slots: {cfg.mem_slots_nodes}")
print(f" - Edge memory slots: {cfg.mem_slots_edges}")
print(f" - Decision branches: {cfg.branches}")
# Count parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" - Total parameters: {total_params:,}")
# Test with sample graph data
batch_size, num_nodes = 8, 12
x, A = _make_batch(batch_size, num_nodes, cfg.dim, p_edge=0.3)
print(f"\nTesting with graphs:")
print(f" - Batch size: {batch_size}")
print(f" - Nodes per graph: {num_nodes}")
print(f" - Edge density: ~30%")
print(f" - Total edges: {A.sum().item():.0f}")
# Forward pass
print(f"\n Executing forward pass...")
y, aux = model(x, A)
print(f"Forward pass results:")
print(f" - Output shape: {y.shape}")
print(f" - Node energy: {aux['node_energy_mean']:.4f}")
print(f" - Edge energy: {aux['edge_energy_mean']:.4f}")
# Analyze decision branching
if 'branch_weights_last' in aux:
branch_weights = aux['branch_weights_last']
print(f"\nDecision branching analysis:")
print(f" - Branch weights shape: {branch_weights.shape}")
# Check branching diversity
branch_entropy = -(branch_weights * torch.log(branch_weights + 1e-8)).sum(dim=-1)
avg_entropy = branch_entropy[A > 0].mean().item()
max_entropy = math.log(cfg.branches)
branching_diversity = avg_entropy / max_entropy
print(f" - Average branching entropy: {avg_entropy:.3f}")
print(f" - Branching diversity: {branching_diversity:.1%}")
# Most common branch assignments
most_common_branches = branch_weights.argmax(dim=-1)
for b in range(cfg.branches):
count = (most_common_branches == b).sum().item()
total_edges = (A > 0).sum().item()
pct = count / max(total_edges, 1) * 100
print(f" - Branch {b}: {count} edges ({pct:.1f}%)")
# Test memory retrieval patterns
print(f"\n Testing memory components...")
# Test individual Hopfield memory
test_memory = HopfieldMemory(cfg.dim, mem_slots=16)
test_input = torch.randn(4, cfg.dim)
retrieved, energy = test_memory(test_input)
print(f" - Memory retrieval shape: {retrieved.shape}")
print(f" - Memory energy: {energy:.4f}")
# Test decision gate
test_gate = DecisionGate(cfg.dim, branches=cfg.branches)
test_nodes = torch.randn(2, 6, cfg.dim)
test_adj = torch.randint(0, 2, (2, 6, 6)).float()
gate_weights = test_gate(test_nodes, test_adj)
print(f" - Gate output shape: {gate_weights.shape}")
print(f" - Gate simplex check: {torch.allclose(gate_weights.sum(-1), test_adj, atol=1e-5)}")
# Test different graph structures
print(f"\n Testing structural adaptivity...")
# Dense graph
dense_A = torch.ones(1, num_nodes, num_nodes) - torch.eye(num_nodes).unsqueeze(0)
dense_x = torch.randn(1, num_nodes, cfg.dim)
dense_y, dense_aux = model(dense_x, dense_A)
# Sparse graph
sparse_A = torch.zeros(1, num_nodes, num_nodes)
sparse_A[0, 0, 1] = sparse_A[0, 1, 2] = sparse_A[0, 2, 0] = 1 # Simple cycle
sparse_x = torch.randn(1, num_nodes, cfg.dim)
sparse_y, sparse_aux = model(sparse_x, sparse_A)
print(f" - Dense graph node energy: {dense_aux['node_energy_mean']:.4f}")
print(f" - Sparse graph node energy: {sparse_aux['node_energy_mean']:.4f}")
print(f" - Dense graph edge energy: {dense_aux['edge_energy_mean']:.4f}")
print(f" - Sparse graph edge energy: {sparse_aux['edge_energy_mean']:.4f}")
# Test evaluation mode (hard routing)
print(f"\n Testing evaluation mode...")
model.eval()
with torch.no_grad():
eval_y, eval_aux = model(x[:2], A[:2])
if 'branch_weights_last' in eval_aux:
eval_weights = eval_aux['branch_weights_last']
# Check for one-hot vectors (hard routing)
max_vals = eval_weights.max(dim=-1)[0]
edges_mask = (A[:2] > 0)
hard_routing_check = torch.allclose(max_vals[edges_mask], torch.ones_like(max_vals[edges_mask]))
print(f" - Hard routing active: {hard_routing_check}")
model.train()
print(f"\n Hopfield Decision Graph test completed!")
print("✓ Dual memory architecture (nodes + edges)")
print("✓ Decision-tree edge routing with soft branching")
print("✓ Energy-based associative memory retrieval")
print("✓ Hard/soft routing modes for training/evaluation")
print("✓ Memory-augmented graph message passing")
print("✓ Adaptive edge semantics based on node context")
return True
def memory_pattern_demo():
"""Demonstrate memory pattern storage and retrieval."""
print("\n" + "="*60)
print(" MEMORY PATTERN DEMONSTRATION")
print("="*60)
# Create simple memory for clear demonstration
memory = HopfieldMemory(dim=8, mem_slots=4)
# Store some patterns manually
with torch.no_grad():
memory.keys[0] = torch.tensor([1, 0, 1, 0, 1, 0, 1, 0], dtype=torch.float32)
memory.keys[1] = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.float32)
memory.keys[2] = torch.tensor([1, 1, 0, 0, 1, 1, 0, 0], dtype=torch.float32)
memory.keys[3] = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.float32)
memory.vals.copy_(memory.keys) # Values same as keys for simplicity
# Test pattern completion
print("Testing pattern completion:")
test_patterns = [
torch.tensor([1, 0, 1, 0, 0.5, 0, 0.8, 0], dtype=torch.float32), # Noisy pattern 0
torch.tensor([0, 1, 0, 1, 0.2, 1, 0, 0.9], dtype=torch.float32), # Noisy pattern 1
torch.tensor([1, 1, 0, 0, 0.7, 0.8, 0, 0], dtype=torch.float32), # Noisy pattern 2
]
for i, noisy_pattern in enumerate(test_patterns):
retrieved, energy = memory(noisy_pattern.unsqueeze(0))
retrieved = retrieved.squeeze(0)
print(f"\n Test {i+1}:")
print(f" Input: {noisy_pattern.numpy()}")
print(f" Retrieved: {retrieved.detach().numpy()}")
print(f" Energy: {energy.item():.3f}")
print("\n Memory demonstrates pattern completion and associative recall!")
print(" Noisy inputs are cleaned up to stored prototype patterns")
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
test_shapes_and_aux()
test_gate_eval_hard_routing()
test_gradient_flow()
test_batching_invariance()
test_shape_errors()
test_hopfield_decision_graph()
memory_pattern_demo()
print("\nAll tests passed")
###########################################################################################################################################
###########################################################################################################################################