flownet / utils /math_utils.py
Ashu9675's picture
Add FlowNet: Post-Transformer Architecture
d4fff7c
Raw
History Blame Contribute Delete
11.4 kB
"""
Mathematical utilities for FlowNet.
Wave mechanics, topology helpers, and novel computation primitives.
"""
import math
from collections import deque
from typing import Tuple, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
# =============================================================================
# Wave Mathematics
# =============================================================================
def complex_wave(amplitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
    """Build the complex wave ``amplitude * exp(i * phase)``.

    This polar form is the fundamental representation for wave-based
    processing: the magnitude of each entry is ``amplitude`` and its
    argument is ``phase``.
    """
    rotation = torch.exp(phase * 1j)
    return amplitude * rotation
def wave_interference(wave1: torch.Tensor, wave2: torch.Tensor) -> torch.Tensor:
    """Return the magnitude of the superposition of two waves.

    Components that are in phase add up (constructive interference, larger
    magnitude); components that are out of phase cancel (destructive
    interference). The result is the real-valued ``|wave1 + wave2|``.
    """
    return (wave1 + wave2).abs()
def phase_coherence(phases: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return the Kuramoto order parameter of ``phases`` along ``dim``.

        R = |(1/N) * sum_j exp(i * theta_j)|

    R is 1.0 when every phase along ``dim`` is identical and approaches 0.0
    when the phases are spread evenly around the circle. It is used as a
    measure of "semantic alignment" between particles.
    """
    unit_phasors = torch.exp(phases * 1j)
    mean_phasor = unit_phasors.mean(dim=dim)
    return mean_phasor.abs()
def resonance_detect(
    wave_field: torch.Tensor,
    min_amplitude: float = 0.1,
    coherence_threshold: float = 0.7
) -> List[torch.Tensor]:
    """Detect standing wave patterns (resonances) in a wave field.

    Resonances represent stable meaning patterns — when multiple particles
    oscillate in sync, their shared meaning is amplified.

    Particles are grouped greedily. Each not-yet-assigned particle whose
    amplitude exceeds ``min_amplitude`` seeds a new group. That group holds
    every not-yet-assigned active particle whose phase differs from the
    seed's by less than ``coherence_threshold``. Each particle belongs to at
    most one group.

    Args:
        wave_field: expected to be a 1-D tensor with one (presumably complex)
            entry per particle; amplitude and phase are taken via ``abs`` and
            ``angle``.
        min_amplitude: particles with ``|wave| <= min_amplitude`` are ignored.
        coherence_threshold: maximum wrapped phase difference, in radians
            (differences lie in [0, pi]), for two particles to count as
            "in sync".

    Returns:
        List of 1-D index tensors, one per resonance, ordered by seed index.
    """
    amplitudes = torch.abs(wave_field)
    phases = torch.angle(wave_field)

    # Find high-amplitude particles
    active_mask = amplitudes > min_amplitude

    n = phases.shape[0]
    # Keep the bookkeeping on the field's device so the boolean-mask
    # operations below never mix CPU and accelerator tensors.
    visited = torch.zeros(n, dtype=torch.bool, device=wave_field.device)

    # angle() lies in [-pi, pi], so raw differences lie in [0, 2pi]; folding
    # gives the shortest angular distance in [0, pi].
    phase_diff_matrix = torch.abs(phases.unsqueeze(0) - phases.unsqueeze(1))
    phase_diff_matrix = torch.min(phase_diff_matrix, 2 * np.pi - phase_diff_matrix)

    resonances = []
    for i in range(n):
        if visited[i] or not active_mask[i]:
            continue
        # Exclude particles already claimed by an earlier resonance so the
        # groups are disjoint.
        in_phase = (
            (phase_diff_matrix[i] < coherence_threshold) & active_mask & ~visited
        )
        resonances.append(torch.where(in_phase)[0])
        visited |= in_phase
    return resonances
def wave_propagate(
    source_amplitude: torch.Tensor,
    source_phase: torch.Tensor,
    distance: torch.Tensor,
    decay_rate: float = 0.1,
    wavelength: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Propagate a wave through space with distance-based decay.

    Waves carry information across the particle field without needing
    direct pairwise attention.

    The amplitude falls off hyperbolically, ``A / (1 + decay_rate * d)``.
    The phase advances by ``2*pi * d / wavelength``, which models the
    propagation delay.

    Returns:
        ``(amplitude_at_distance, phase_at_distance)``; the phase is not
        wrapped.
    """
    attenuation = 1.0 + decay_rate * distance
    phase_shift = 2 * np.pi * distance / wavelength
    return source_amplitude / attenuation, source_phase + phase_shift
# =============================================================================
# Topological Mathematics
# =============================================================================
def pairwise_distances(positions: torch.Tensor) -> torch.Tensor:
    """Return the matrix of Euclidean distances between all particle pairs.

    Entry ``[i, j]`` is ``||positions[j] - positions[i]||`` taken over the
    last dimension.
    """
    rows = positions.unsqueeze(1)
    cols = positions.unsqueeze(0)
    return (cols - rows).norm(dim=-1)
def build_adjacency(
    positions: torch.Tensor,
    threshold: float = 1.0,
    use_bonds: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Build an adjacency matrix for the particle field.

    If ``use_bonds`` is given, it is returned unchanged and ``positions`` /
    ``threshold`` are ignored. Otherwise, two particles are adjacent (1.0)
    when their Euclidean distance is strictly below ``threshold``. The
    diagonal is forced to 0 so there are no self-loops.
    """
    if use_bonds is None:
        within_reach = pairwise_distances(positions) < threshold
        adjacency = within_reach.float()
        adjacency.fill_diagonal_(0)
        return adjacency
    return use_bonds
def compute_betti_numbers(adjacency: torch.Tensor, max_dim: int = 2) -> List[int]:
    """Compute Betti numbers (topological features) from adjacency.

    β₀ = number of connected components
    β₁ = number of independent cycles of the graph, E - V + β₀
    β₂ = number of voids (not computed; always 0)

    This is a simplified computation for the prototype.
    Full persistent homology would use giotto-tda or similar.

    Args:
        adjacency: square [n, n] matrix; off-diagonal entries > 0 are edges.
            Assumed symmetric (undirected). For an asymmetric matrix the
            traversal follows row -> column entries only, and E is half the
            number of directed entries. Diagonal entries (self-loops) are
            ignored.
        max_dim: currently unused; kept for API compatibility.

    Returns:
        ``[beta_0, beta_1, beta_2]`` as Python ints.
    """
    n = adjacency.shape[0]

    # One device->host transfer; the traversal below is plain Python.
    # Self-loops are dropped: they are not edges between distinct vertices
    # and would otherwise inflate E (and thus β₁) by half an edge each.
    edge_pairs = [
        (src, dst)
        for src, dst in torch.nonzero(adjacency > 0).tolist()
        if src != dst
    ]
    neighbors: List[List[int]] = [[] for _ in range(n)]
    for src, dst in edge_pairs:
        neighbors[src].append(dst)

    # β₀: connected components via BFS
    visited = [False] * n
    beta_0 = 0
    for start in range(n):
        if visited[start]:
            continue
        beta_0 += 1
        visited[start] = True
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in neighbors[node]:
                if not visited[nxt]:
                    visited[nxt] = True
                    queue.append(nxt)

    # β₁: cycle rank of the graph (edges - vertices + components); each
    # undirected edge appears twice in a symmetric matrix.
    num_edges = len(edge_pairs) // 2
    beta_1 = max(0, num_edges - n + beta_0)

    # β₂: placeholder (would need simplicial complex computation)
    beta_2 = 0

    return [beta_0, beta_1, beta_2]
def topological_signature(adjacency: torch.Tensor) -> torch.Tensor:
    """Compute a topological signature vector for memory comparison.

    This is a fixed-size representation of the topology that can be
    used for similarity search in topological memory.

    Args:
        adjacency: square [n, n] matrix; entries > 0 count as edges.

    Returns:
        float32 tensor of 8 features, in this order: β₀, β₁, β₂, node count,
        mean degree, degree standard deviation, max degree, and edge count
        (positive entries / 2). Degree statistics that are undefined are
        reported as 0.0 rather than raising or producing NaN: all three for
        an empty graph, and the std for a single node.
    """
    betti = compute_betti_numbers(adjacency)

    n = adjacency.shape[0]
    is_edge = adjacency > 0
    degrees = torch.sum(is_edge, dim=-1).float()

    # max() of an empty tensor raises and mean() gives NaN; the default
    # (Bessel-corrected) std is NaN for a single element. Substitute 0.0 so
    # every signature is finite and comparable.
    mean_degree = degrees.mean().item() if n > 0 else 0.0
    std_degree = degrees.std().item() if n > 1 else 0.0
    max_degree = degrees.max().item() if n > 0 else 0.0
    edge_count = float(torch.sum(is_edge).item()) / 2

    features = torch.tensor([
        float(betti[0]),  # connected components
        float(betti[1]),  # loops
        float(betti[2]),  # voids
        float(n),         # size
        mean_degree,      # average degree
        std_degree,       # degree standard deviation
        max_degree,       # max degree
        edge_count,       # edge count
    ], dtype=torch.float32)
    return features
# =============================================================================
# Dynamical Systems
# =============================================================================
def kuramoto_step(
    phases: torch.Tensor,
    frequencies: torch.Tensor,
    coupling: torch.Tensor,
    dt: float = 0.01
) -> torch.Tensor:
    """One explicit-Euler step of the Kuramoto model for phase synchronization.

    dθᵢ/dt = ωᵢ + Σⱼ Kᵢⱼ sin(θⱼ - θᵢ)

    This governs how particles synchronize their phases: with positive
    coupling, each phase is pulled toward the phases it is coupled to.
    Synchronization = semantic alignment.

    Args:
        phases: [..., n] current phases θ (leading batch dims allowed).
        frequencies: natural frequencies ω, broadcastable to ``phases``.
        coupling: [..., n, n] coupling matrix; ``coupling[..., i, j]`` is
            Kᵢⱼ, the strength with which oscillator j pulls oscillator i.
        dt: integration step size.

    Returns:
        Updated phases wrapped to [0, 2π).
    """
    # phase_diff[..., i, j] = θⱼ - θᵢ
    phase_diff = phases.unsqueeze(-2) - phases.unsqueeze(-1)

    # Coupling force Σⱼ Kᵢⱼ sin(θⱼ - θᵢ): positive when oscillator i lags
    # its neighbours, so it is pulled forward toward them.
    sync_force = torch.sum(coupling * torch.sin(phase_diff), dim=-1)

    new_phases = phases + dt * (frequencies + sync_force)

    # Wrap to [0, 2π)
    return new_phases % (2 * np.pi)
def lennard_jones_force(
    positions: torch.Tensor,
    epsilon: float = 1.0,
    sigma: float = 1.0,
    cutoff: float = 3.0
) -> torch.Tensor:
    """Compute Lennard-Jones-like forces between particles.

    Repulsion at close range, attraction at medium range. The crossover is
    at r = 2^(1/6)·σ.

    This creates natural "preferred distances" between particles —
    semantically related particles settle at stable distances.

    F = 24ε * [(σ/r)⁶ * (2(σ/r)⁶ - 1)] / r     (= -dU/dr; > 0 means repulsion)

    Pairs at distance >= ``cutoff`` contribute nothing. Distances are
    clamped to at least 0.1 to avoid division by zero; exactly coincident
    particles exert no force on each other because their direction vector
    is zero.

    Args:
        positions: [n, dim] particle positions.
        epsilon: well depth ε.
        sigma: length scale σ.
        cutoff: interaction cutoff distance.

    Returns:
        Net force vector on each particle, [n, dim].
    """
    # diff[i, j] = x_j - x_i, the vector from particle i to particle j
    diff = positions.unsqueeze(0) - positions.unsqueeze(1)  # [n, n, dim]
    distances = torch.norm(diff, dim=-1, keepdim=True)  # [n, n, 1]
    distances = torch.clamp(distances, min=0.1)  # avoid division by zero

    # Apply cutoff and drop self-interaction
    mask = (distances.squeeze(-1) < cutoff).float()
    mask.fill_diagonal_(0)

    # LJ force magnitude (positive = repulsive)
    sr6 = (sigma / distances) ** 6
    force_mag = 24 * epsilon * sr6 * (2 * sr6 - 1) / distances

    # Unit vector from j to i: a repulsive (positive) magnitude then pushes
    # particle i away from particle j, and an attractive one pulls it closer.
    direction = -diff / distances
    forces = force_mag * direction * mask.unsqueeze(-1)

    # Sum contributions from all j to get the net force on each i
    return forces.sum(dim=1)  # [n, dim]
def simulated_annealing_step(
    state: torch.Tensor,
    energy_fn,
    temperature: float,
    perturbation_scale: float = 0.1
) -> Tuple[torch.Tensor, float]:
    """One step of simulated annealing for energy minimization.

    Used by the Consistency Field to find low-energy (true) states.
    High temperature → accept worse states (explore)
    Low temperature → only accept better states (exploit)

    Args:
        state: current state tensor.
        energy_fn: callable mapping a state to a scalar energy. It may
            return a Python number or a single-element tensor.
        temperature: Metropolis temperature; values below 1e-8 are treated
            as 1e-8.
        perturbation_scale: standard deviation of the Gaussian proposal.

    Returns:
        ``(new_state, energy)`` where ``energy`` is the value ``energy_fn``
        returned for ``new_state`` (its type is passed through unchanged).
    """
    current_energy = energy_fn(state)

    # Propose perturbation
    perturbation = perturbation_scale * torch.randn_like(state)
    proposed_state = state + perturbation
    proposed_energy = energy_fn(proposed_state)

    # Metropolis criterion. float() accepts both Python numbers and
    # single-element tensors, so energy_fn may return either; torch.exp
    # would reject a plain float.
    delta_e = float(proposed_energy - current_energy)
    if delta_e < 0:
        # Always accept improvements
        return proposed_state, proposed_energy

    # Accept worse states with probability exp(-ΔE/T); ΔE >= 0 here, so the
    # exponent is <= 0 and cannot overflow.
    acceptance_prob = math.exp(-delta_e / max(temperature, 1e-8))
    if torch.rand(1).item() < acceptance_prob:
        return proposed_state, proposed_energy
    return state, current_energy
# =============================================================================
# Activation & Normalization (Novel)
# =============================================================================
def phase_activation(x: torch.Tensor) -> torch.Tensor:
    """Squash ``x`` into phase space via ``2π · sigmoid(x)``.

    The output lies in (0, 2π); in floating point it may saturate to the
    endpoints for very large |x|. Unlike ReLU/GELU, the output range matches
    one full period, which suits periodic (phase) quantities.
    """
    doubled = torch.sigmoid(x) * 2
    return doubled * np.pi
def amplitude_activation(x: torch.Tensor) -> torch.Tensor:
    """Map ``x`` to a non-negative amplitude with softplus, ``log(1 + e^x)``.

    Softplus is smooth everywhere, which gives smooth gradients. It is
    strictly positive in exact arithmetic, but it is NOT bounded above: it
    grows like ``x`` for large inputs.
    """
    return F.softplus(x, beta=1.0, threshold=20.0)
def topological_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Scale every vector along ``dim`` to (approximately) unit L2 norm.

    Intended as an alternative to LayerNorm: each vector keeps its direction
    and only its magnitude is normalized, rather than re-centring and
    re-scaling individual features.

    Args:
        x: input tensor.
        dim: dimension along which the L2 norm is taken.

    Returns:
        ``x / (||x||₂ + 1e-8)`` along ``dim``; all-zero vectors stay zero.
    """
    magnitude = torch.norm(x, dim=dim, keepdim=True)
    # The epsilon keeps all-zero vectors finite (they map to zero). Because
    # magnitude has size 1 along `dim`, there is nothing further to rescale.
    return x / (magnitude + 1e-8)